--- /dev/null
+.idea*
\ No newline at end of file
--- /dev/null
+MIT License
+
+Copyright (c) 2017 HashiCorp
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
--- /dev/null
+# go-hclog
+
+[![Go Documentation](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)][godocs]
+
+[godocs]: https://godoc.org/github.com/hashicorp/go-hclog
+
+`go-hclog` is a package for Go that provides a simple key/value logging
+interface for use in development and production environments.
+
+Unlike the standard library `log` package, it provides log levels that let
+you reduce output to only the desired amount.
+
+It provides `Printf` style logging of values via `hclog.Fmt()`.
+
+It provides a human readable output mode for use in development as well as
+JSON output mode for production.
+
+## Stability Note
+
+While this library is fully open source and HashiCorp will be maintaining it
+(since we are and will be making extensive use of it), the API and output
+format is subject to minor changes as we fully bake and vet it in our projects.
+This notice will be removed once it's fully integrated into our major projects
+and no further changes are anticipated.
+
+## Installation and Docs
+
+Install using `go get github.com/hashicorp/go-hclog`.
+
+Full documentation is available at
+http://godoc.org/github.com/hashicorp/go-hclog
+
+## Usage
+
+### Use the global logger
+
+```go
+hclog.Default().Info("hello world")
+```
+
+```text
+2017-07-05T16:15:55.167-0700 [INFO ] hello world
+```
+
+(Note timestamps are removed in future examples for brevity.)
+
+### Create a new logger
+
+```go
+appLogger := hclog.New(&hclog.LoggerOptions{
+ Name: "my-app",
+ Level: hclog.LevelFromString("DEBUG"),
+})
+```
+
+### Emit an Info level message with 2 key/value pairs
+
+```go
+input := "5.5"
+_, err := strconv.ParseInt(input, 10, 32)
+if err != nil {
+ appLogger.Info("Invalid input for ParseInt", "input", input, "error", err)
+}
+```
+
+```text
+... [INFO ] my-app: Invalid input for ParseInt: input=5.5 error="strconv.ParseInt: parsing "5.5": invalid syntax"
+```
+
+### Create a new Logger for a major subsystem
+
+```go
+subsystemLogger := appLogger.Named("transport")
+subsystemLogger.Info("we are transporting something")
+```
+
+```text
+... [INFO ] my-app.transport: we are transporting something
+```
+
+Notice that logs emitted by `subsystemLogger` contain `my-app.transport`,
+reflecting both the application and subsystem names.
+
+### Create a new Logger with fixed key/value pairs
+
+Using `With()` will include a specific key-value pair in all messages emitted
+by that logger.
+
+```go
+requestID := "5fb446b6-6eba-821d-df1b-cd7501b6a363"
+requestLogger := subsystemLogger.With("request", requestID)
+requestLogger.Info("we are transporting a request")
+```
+
+```text
+... [INFO ] my-app.transport: we are transporting a request: request=5fb446b6-6eba-821d-df1b-cd7501b6a363
+```
+
+This allows sub Loggers to be context specific without having to thread that
+into all the callers.
+
+### Using `hclog.Fmt()`
+
+```go
+totalBandwidth := 200
+appLogger.Info("total bandwidth exceeded", "bandwidth", hclog.Fmt("%d GB/s", totalBandwidth))
+```
+
+```text
+... [INFO ] my-app: total bandwidth exceeded: bandwidth="200 GB/s"
+```
+
+### Use this with code that uses the standard library logger
+
+If you want to use the standard library's `log.Logger` interface you can wrap
+`hclog.Logger` by calling the `StandardLogger()` method. This allows you to use
+it with the familiar `Println()`, `Printf()`, etc. For example:
+
+```go
+stdLogger := appLogger.StandardLogger(&hclog.StandardLoggerOptions{
+ InferLevels: true,
+})
+// Printf() is provided by stdlib log.Logger interface, not hclog.Logger
+stdLogger.Printf("[DEBUG] %+v", stdLogger)
+```
+
+```text
+... [DEBUG] my-app: &{mu:{state:0 sema:0} prefix: flag:0 out:0xc42000a0a0 buf:[]}
+```
+
+Notice that if `appLogger` is initialized with the `INFO` log level _and_ you
+specify `InferLevels: true`, you will not see any output here. You must change
+`appLogger` to `DEBUG` to see output. See the docs for more information.
--- /dev/null
+package hclog
+
+import (
+ "sync"
+)
+
+var (
+    // protect guards the one-time, lazy creation of the Default logger.
+    protect sync.Once
+    // def holds the process-wide default Logger built by Default().
+    def Logger
+
+    // The options used to create the Default logger. These are
+    // read only when the Default logger is created, so set them
+    // as soon as the process starts.
+    DefaultOptions = &LoggerOptions{
+        Level: DefaultLevel,
+        Output: DefaultOutput,
+    }
+)
+
+// Default returns the Logger held globally, creating it on first use from
+// DefaultOptions. This can be a good starting place; derive more specific
+// sub-loggers from it with With() and Named().
+func Default() Logger {
+    protect.Do(func() { def = New(DefaultOptions) })
+    return def
+}
+
+// L is a short alias for Default().
+func L() Logger {
+    return Default()
+}
--- /dev/null
+package hclog
+
+import (
+ "bufio"
+ "encoding"
+ "encoding/json"
+ "fmt"
+ "log"
+ "os"
+ "runtime"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+var (
+    // _levelToBracket maps each level to the fixed-width bracket tag used
+    // by the human-readable (non-JSON) output format; the padding keeps
+    // columns aligned across levels.
+    _levelToBracket = map[Level]string{
+        Debug: "[DEBUG]",
+        Trace: "[TRACE]",
+        Info: "[INFO ]",
+        Warn: "[WARN ]",
+        Error: "[ERROR]",
+    }
+)
+
+// New creates a Logger configured by opts. A nil opts selects every
+// default: DefaultLevel threshold, os.Stderr output, human-readable
+// (non-JSON) format, and the package-level TimeFormat.
+func New(opts *LoggerOptions) Logger {
+    if opts == nil {
+        opts = &LoggerOptions{}
+    }
+
+    sink := opts.Output
+    if sink == nil {
+        sink = os.Stderr
+    }
+
+    lvl := opts.Level
+    if lvl == NoLevel {
+        lvl = DefaultLevel
+    }
+
+    lock := opts.Mutex
+    if lock == nil {
+        lock = new(sync.Mutex)
+    }
+
+    format := TimeFormat
+    if opts.TimeFormat != "" {
+        format = opts.TimeFormat
+    }
+
+    l := &intLogger{
+        m:          lock,
+        json:       opts.JSONFormat,
+        caller:     opts.IncludeLocation,
+        name:       opts.Name,
+        timeFormat: format,
+        w:          bufio.NewWriter(sink),
+        level:      new(int32),
+    }
+    atomic.StoreInt32(l.level, int32(lvl))
+    return l
+}
+
+// The internal logger implementation. Internal in that it is defined entirely
+// by this package.
+type intLogger struct {
+    json bool // emit JSON objects instead of the text format
+    caller bool // include file:line of the call site in each entry
+    name string // dotted subsystem prefix; "" means no prefix
+    timeFormat string // layout string passed to time.Time.Format
+
+    // this is a pointer so that it's shared by any derived loggers, since
+    // those derived loggers share the bufio.Writer as well.
+    m *sync.Mutex
+    w *bufio.Writer
+    level *int32 // read/written atomically so SetLevel is safe concurrently
+
+    implied []interface{} // key/value pairs (from With) added to every entry
+}
+
+// Make sure that intLogger is a Logger
+var _ Logger = &intLogger{}
+
+// TimeFormat is the default timestamp layout used for logging. This is a
+// version of RFC3339 that contains millisecond precision.
+const TimeFormat = "2006-01-02T15:04:05.000Z0700"
+
+// Log a message and a set of key/value pairs if the given level is at
+// or more severe than the threshold configured in the Logger.
+func (z *intLogger) Log(level Level, msg string, args ...interface{}) {
+    if level < Level(atomic.LoadInt32(z.level)) {
+        return
+    }
+
+    // Timestamp is captured before acquiring the lock.
+    t := time.Now()
+
+    z.m.Lock()
+    defer z.m.Unlock()
+
+    // log/logJson must be called directly from here: both resolve the
+    // call site with runtime.Caller(3), which counts this frame.
+    if z.json {
+        z.logJson(t, level, msg, args...)
+    } else {
+        z.log(t, level, msg, args...)
+    }
+
+    // Flush per entry so output is visible immediately.
+    z.w.Flush()
+}
+
+// trimCallerPath reduces a file path to its final two segments
+// (dir/file). Shorter paths are returned unchanged.
+//
+// Lovely borrowed from zap. Note that to trim correctly on Windows too,
+// we counter-intuitively use '/' and *not* os.PathSeparator, because the
+// path originates from the Go runtime (runtime.Caller), which as of
+// Mar/17 returns forward slashes even on Windows.
+//
+// See https://github.com/golang/go/issues/3335
+// and https://github.com/golang/go/issues/18151
+// for discussion on the issue on Go side.
+func trimCallerPath(path string) string {
+    // Walk back over the last two separators; if either is missing the
+    // path already has at most two segments.
+    last := strings.LastIndexByte(path, '/')
+    if last == -1 {
+        return path
+    }
+
+    penultimate := strings.LastIndexByte(path[:last], '/')
+    if penultimate == -1 {
+        return path
+    }
+
+    return path[penultimate+1:]
+}
+
+// log writes one entry in the human-readable format:
+//
+//	<time> [LEVEL] <file:line:> <name>: <msg>: k=v ...
+//
+// It must be called with z.m held, and only directly from Log so the
+// runtime.Caller frame depth below stays correct.
+func (z *intLogger) log(t time.Time, level Level, msg string, args ...interface{}) {
+    z.w.WriteString(t.Format(z.timeFormat))
+    z.w.WriteByte(' ')
+
+    s, ok := _levelToBracket[level]
+    if ok {
+        z.w.WriteString(s)
+    } else {
+        z.w.WriteString("[UNKN ]")
+    }
+
+    if z.caller {
+        // Depth 3: log <- Log <- leveled wrapper (Info, etc.) <- caller.
+        if _, file, line, ok := runtime.Caller(3); ok {
+            z.w.WriteByte(' ')
+            z.w.WriteString(trimCallerPath(file))
+            z.w.WriteByte(':')
+            z.w.WriteString(strconv.Itoa(line))
+            z.w.WriteByte(':')
+        }
+    }
+
+    z.w.WriteByte(' ')
+
+    if z.name != "" {
+        z.w.WriteString(z.name)
+        z.w.WriteString(": ")
+    }
+
+    z.w.WriteString(msg)
+
+    args = append(z.implied, args...)
+
+    var stacktrace CapturedStacktrace
+
+    // A nil slice has len 0, so the former `args != nil &&` guard was
+    // redundant (staticcheck S1009).
+    if len(args) > 0 {
+        if len(args)%2 != 0 {
+            // An odd trailing argument is either a stacktrace request or
+            // a key that lost its value; pad the latter so pairing holds.
+            cs, ok := args[len(args)-1].(CapturedStacktrace)
+            if ok {
+                args = args[:len(args)-1]
+                stacktrace = cs
+            } else {
+                args = append(args, "<unknown>")
+            }
+        }
+
+        z.w.WriteByte(':')
+
+    FOR:
+        for i := 0; i < len(args); i = i + 2 {
+            var val string
+
+            switch st := args[i+1].(type) {
+            case string:
+                val = st
+            case int:
+                val = strconv.FormatInt(int64(st), 10)
+            case int64:
+                val = strconv.FormatInt(int64(st), 10)
+            case int32:
+                val = strconv.FormatInt(int64(st), 10)
+            case int16:
+                val = strconv.FormatInt(int64(st), 10)
+            case int8:
+                val = strconv.FormatInt(int64(st), 10)
+            case uint:
+                val = strconv.FormatUint(uint64(st), 10)
+            case uint64:
+                val = strconv.FormatUint(uint64(st), 10)
+            case uint32:
+                val = strconv.FormatUint(uint64(st), 10)
+            case uint16:
+                val = strconv.FormatUint(uint64(st), 10)
+            case uint8:
+                val = strconv.FormatUint(uint64(st), 10)
+            case CapturedStacktrace:
+                stacktrace = st
+                continue FOR
+            case Format:
+                val = fmt.Sprintf(st[0].(string), st[1:]...)
+            default:
+                val = fmt.Sprintf("%v", st)
+            }
+
+            z.w.WriteByte(' ')
+            // Keys are documented to be strings, but a bad key should not
+            // panic the process (the old args[i].(string) assertion did);
+            // render it best-effort instead, matching logJson's tolerance.
+            switch key := args[i].(type) {
+            case string:
+                z.w.WriteString(key)
+            default:
+                z.w.WriteString(fmt.Sprintf("%v", key))
+            }
+            z.w.WriteByte('=')
+
+            // Quote values containing whitespace so pairs stay parseable.
+            if strings.ContainsAny(val, " \t\n\r") {
+                z.w.WriteByte('"')
+                z.w.WriteString(val)
+                z.w.WriteByte('"')
+            } else {
+                z.w.WriteString(val)
+            }
+        }
+    }
+
+    z.w.WriteString("\n")
+
+    if stacktrace != "" {
+        z.w.WriteString(string(stacktrace))
+    }
+}
+
+// logJson writes one entry as a JSON object. Metadata keys carry an "@"
+// prefix (@message, @timestamp, @level, @module, @caller) so they cannot
+// collide with user-supplied pairs. It must be called with z.m held, and
+// only directly from Log (runtime.Caller frame depth).
+func (z *intLogger) logJson(t time.Time, level Level, msg string, args ...interface{}) {
+    vals := map[string]interface{}{
+        "@message": msg,
+        "@timestamp": t.Format("2006-01-02T15:04:05.000000Z07:00"),
+    }
+
+    var levelStr string
+    switch level {
+    case Error:
+        levelStr = "error"
+    case Warn:
+        levelStr = "warn"
+    case Info:
+        levelStr = "info"
+    case Debug:
+        levelStr = "debug"
+    case Trace:
+        levelStr = "trace"
+    default:
+        levelStr = "all"
+    }
+
+    vals["@level"] = levelStr
+
+    if z.name != "" {
+        vals["@module"] = z.name
+    }
+
+    if z.caller {
+        // Depth 3: logJson <- Log <- leveled wrapper <- caller.
+        if _, file, line, ok := runtime.Caller(3); ok {
+            vals["@caller"] = fmt.Sprintf("%s:%d", file, line)
+        }
+    }
+
+    args = append(z.implied, args...)
+
+    // A nil slice has len 0, so the former `args != nil &&` guard was
+    // redundant (staticcheck S1009).
+    if len(args) > 0 {
+        if len(args)%2 != 0 {
+            // An odd trailing argument is either a stacktrace or a key
+            // that lost its value; pad the latter so pairing holds.
+            cs, ok := args[len(args)-1].(CapturedStacktrace)
+            if ok {
+                args = args[:len(args)-1]
+                vals["stacktrace"] = cs
+            } else {
+                args = append(args, "<unknown>")
+            }
+        }
+
+        for i := 0; i < len(args); i = i + 2 {
+            if _, ok := args[i].(string); !ok {
+                // As this is the logging function not much we can do here
+                // without injecting into logs...
+                continue
+            }
+            val := args[i+1]
+            switch sv := val.(type) {
+            case error:
+                // Check if val is of type error. If error type doesn't
+                // implement json.Marshaler or encoding.TextMarshaler
+                // then set val to err.Error() so that it gets marshaled
+                switch sv.(type) {
+                case json.Marshaler, encoding.TextMarshaler:
+                default:
+                    val = sv.Error()
+                }
+            case Format:
+                val = fmt.Sprintf(sv[0].(string), sv[1:]...)
+            }
+
+            vals[args[i].(string)] = val
+        }
+    }
+
+    err := json.NewEncoder(z.w).Encode(vals)
+    if err != nil {
+        // An unmarshalable value has no error path out of a logger;
+        // fail loudly rather than drop the entry silently.
+        panic(err)
+    }
+}
+
+// Debug emits the message and args at DEBUG level.
+func (z *intLogger) Debug(msg string, args ...interface{}) {
+    z.Log(Debug, msg, args...)
+}
+
+// Trace emits the message and args at TRACE level.
+func (z *intLogger) Trace(msg string, args ...interface{}) {
+    z.Log(Trace, msg, args...)
+}
+
+// Info emits the message and args at INFO level.
+func (z *intLogger) Info(msg string, args ...interface{}) {
+    z.Log(Info, msg, args...)
+}
+
+// Warn emits the message and args at WARN level.
+func (z *intLogger) Warn(msg string, args ...interface{}) {
+    z.Log(Warn, msg, args...)
+}
+
+// Error emits the message and args at ERROR level.
+func (z *intLogger) Error(msg string, args ...interface{}) {
+    z.Log(Error, msg, args...)
+}
+
+// IsTrace indicates whether the logger would emit TRACE level logs.
+// Uses <= like the other Is* guards: if the stored threshold were ever
+// below Trace (e.g. SetLevel(NoLevel)), Log would emit trace entries,
+// so the old `== Trace` comparison would have answered inconsistently.
+func (z *intLogger) IsTrace() bool {
+    return Level(atomic.LoadInt32(z.level)) <= Trace
+}
+
+// IsDebug indicates whether the logger would emit DEBUG level logs.
+func (z *intLogger) IsDebug() bool {
+    return Level(atomic.LoadInt32(z.level)) <= Debug
+}
+
+// IsInfo indicates whether the logger would emit INFO level logs.
+func (z *intLogger) IsInfo() bool {
+    return Level(atomic.LoadInt32(z.level)) <= Info
+}
+
+// IsWarn indicates whether the logger would emit WARN level logs.
+func (z *intLogger) IsWarn() bool {
+    return Level(atomic.LoadInt32(z.level)) <= Warn
+}
+
+// IsError indicates whether the logger would emit ERROR level logs.
+func (z *intLogger) IsError() bool {
+    return Level(atomic.LoadInt32(z.level)) <= Error
+}
+
+// Return a sub-Logger for which every emitted log message will contain
+// the given key/value pairs. This is used to create a context specific
+// Logger.
+func (z *intLogger) With(args ...interface{}) Logger {
+ if len(args)%2 != 0 {
+ panic("With() call requires paired arguments")
+ }
+
+ var nz intLogger = *z
+
+ result := make(map[string]interface{}, len(z.implied)+len(args))
+ keys := make([]string, 0, len(z.implied)+len(args))
+
+ // Read existing args, store map and key for consistent sorting
+ for i := 0; i < len(z.implied); i += 2 {
+ key := z.implied[i].(string)
+ keys = append(keys, key)
+ result[key] = z.implied[i+1]
+ }
+ // Read new args, store map and key for consistent sorting
+ for i := 0; i < len(args); i += 2 {
+ key := args[i].(string)
+ _, exists := result[key]
+ if !exists {
+ keys = append(keys, key)
+ }
+ result[key] = args[i+1]
+ }
+
+ // Sort keys to be consistent
+ sort.Strings(keys)
+
+ nz.implied = make([]interface{}, 0, len(z.implied)+len(args))
+ for _, k := range keys {
+ nz.implied = append(nz.implied, k)
+ nz.implied = append(nz.implied, result[k])
+ }
+
+ return &nz
+}
+
+// Create a new sub-Logger that a name decending from the current name.
+// This is used to create a subsystem specific Logger.
+func (z *intLogger) Named(name string) Logger {
+ var nz intLogger = *z
+
+ if nz.name != "" {
+ nz.name = nz.name + "." + name
+ } else {
+ nz.name = name
+ }
+
+ return &nz
+}
+
+// Create a new sub-Logger with an explicit name. This ignores the current
+// name. This is used to create a standalone logger that doesn't fall
+// within the normal hierarchy.
+func (z *intLogger) ResetNamed(name string) Logger {
+ var nz intLogger = *z
+
+ nz.name = name
+
+ return &nz
+}
+
+// SetLevel updates the logging level on-the-fly. Because the level word is
+// shared (via pointer) with every derived logger, this affects all
+// subloggers as well.
+func (z *intLogger) SetLevel(level Level) {
+    atomic.StoreInt32(z.level, int32(level))
+}
+
+// StandardLogger creates a *log.Logger that sends its data through this
+// Logger, so packages that expect the standard library logger can use it.
+// opts.InferLevels is passed through to the adapter (see stdlogAdapter,
+// defined elsewhere in this package).
+func (z *intLogger) StandardLogger(opts *StandardLoggerOptions) *log.Logger {
+    if opts == nil {
+        opts = &StandardLoggerOptions{}
+    }
+
+    return log.New(&stdlogAdapter{z, opts.InferLevels}, "", 0)
+}
--- /dev/null
+package hclog
+
+import (
+ "io"
+ "log"
+ "os"
+ "strings"
+ "sync"
+)
+
+var (
+    // DefaultOutput and DefaultLevel are the fallbacks New applies when
+    // LoggerOptions leaves Output or Level unset.
+    DefaultOutput = os.Stderr
+    DefaultLevel = Info
+)
+
+// Level represents a log severity; higher values are more severe.
+type Level int32
+
+const (
+    // This is a special level used to indicate that no level has been
+    // set and allow for a default to be used.
+    NoLevel Level = 0
+
+    // The most verbose level. Intended to be used for the tracing of actions
+    // in code, such as function enters/exits, etc.
+    Trace Level = 1
+
+    // For programmer low-level analysis.
+    Debug Level = 2
+
+    // For information about steady state operations.
+    Info Level = 3
+
+    // For information about rare but handled events.
+    Warn Level = 4
+
+    // For information about unrecoverable events.
+    Error Level = 5
+)
+
+// When processing a value of this type, the logger automatically treats the first
+// argument as a Printf formatting string and passes the rest as the values to be
+// formatted. For example: L.Info(Fmt("%d beans/day", beans)). This is a simple
+// convenience type for when formatting is required.
+type Format []interface{}
+
+// Fmt builds a Format value from a Printf-style format string and its
+// arguments, for use as a logging value. This is a convenience function
+// for creating a Format type.
+func Fmt(str string, args ...interface{}) Format {
+    f := make(Format, 0, len(args)+1)
+    f = append(f, str)
+    return append(f, args...)
+}
+
+// LevelFromString maps a level name to its Level value, ignoring case and
+// surrounding whitespace (so "INFO", "info", and " Info " all work). It
+// returns NoLevel for unrecognized input, which lets callers configure the
+// level from config files or environment variables in a predictable way.
+func LevelFromString(levelStr string) Level {
+    switch strings.ToLower(strings.TrimSpace(levelStr)) {
+    case "trace":
+        return Trace
+    case "debug":
+        return Debug
+    case "info":
+        return Info
+    case "warn":
+        return Warn
+    case "error":
+        return Error
+    }
+    return NoLevel
+}
+
+// The main Logger interface. All code should be written against this
+// interface only.
+type Logger interface {
+    // Args are alternating key, val pairs
+    // keys must be strings
+    // vals can be any type, but display is implementation specific
+    // Emit a message and key/value pairs at the TRACE level
+    Trace(msg string, args ...interface{})
+
+    // Emit a message and key/value pairs at the DEBUG level
+    Debug(msg string, args ...interface{})
+
+    // Emit a message and key/value pairs at the INFO level
+    Info(msg string, args ...interface{})
+
+    // Emit a message and key/value pairs at the WARN level
+    Warn(msg string, args ...interface{})
+
+    // Emit a message and key/value pairs at the ERROR level
+    Error(msg string, args ...interface{})
+
+    // Indicate if TRACE logs would be emitted. This and the other Is* guards
+    // are used to elide expensive logging code based on the current level.
+    IsTrace() bool
+
+    // Indicate if DEBUG logs would be emitted. This and the other Is* guards
+    IsDebug() bool
+
+    // Indicate if INFO logs would be emitted. This and the other Is* guards
+    IsInfo() bool
+
+    // Indicate if WARN logs would be emitted. This and the other Is* guards
+    IsWarn() bool
+
+    // Indicate if ERROR logs would be emitted. This and the other Is* guards
+    IsError() bool
+
+    // Creates a sublogger that will always have the given key/value pairs
+    With(args ...interface{}) Logger
+
+    // Create a logger that will prepend the name string on the front of all messages.
+    // If the logger already has a name, the new value will be appended to the current
+    // name. That way, a major subsystem can use this to decorate all its own logs
+    // without losing context.
+    Named(name string) Logger
+
+    // Create a logger that will prepend the name string on the front of all messages.
+    // This sets the name of the logger to the value directly, unlike Named which honors
+    // the current name as well.
+    ResetNamed(name string) Logger
+
+    // Updates the level. This should affect all sub-loggers as well. If an
+    // implementation cannot update the level on the fly, it should no-op.
+    SetLevel(level Level)
+
+    // Return a value that conforms to the stdlib log.Logger interface
+    StandardLogger(opts *StandardLoggerOptions) *log.Logger
+}
+
+// StandardLoggerOptions configures the *log.Logger adapter returned by
+// Logger.StandardLogger.
+type StandardLoggerOptions struct {
+    // Indicate that some minimal parsing should be done on strings to try
+    // and detect their level and re-emit them.
+    // This supports strings like [ERROR], [ERR], [TRACE], [WARN], [INFO],
+    // [DEBUG], stripping the tag off before reapplying the level.
+    InferLevels bool
+}
+
+// LoggerOptions configures a Logger created by New. The zero value of
+// every field selects a sensible default.
+type LoggerOptions struct {
+    // Name of the subsystem to prefix logs with
+    Name string
+
+    // The threshold for the logger. Anything less severe is suppressed.
+    Level Level
+
+    // Where to write the logs to. Defaults to os.Stderr if nil.
+    Output io.Writer
+
+    // An optional mutex pointer in case Output is shared
+    Mutex *sync.Mutex
+
+    // Control if the output should be in JSON.
+    JSONFormat bool
+
+    // Include file and line information in each log line
+    IncludeLocation bool
+
+    // The time format to use instead of the default (TimeFormat)
+    TimeFormat string
+}
--- /dev/null
+package hclog
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestLogger(t *testing.T) {
+    t.Run("formats log entries", func(t *testing.T) {
+        var buf bytes.Buffer
+
+        logger := New(&LoggerOptions{
+            Name: "test",
+            Output: &buf,
+        })
+
+        logger.Info("this is test", "who", "programmer", "why", "testing")
+
+        str := buf.String()
+
+        dataIdx := strings.IndexByte(str, ' ') // first space ends the timestamp
+
+        // ts := str[:dataIdx]
+        rest := str[dataIdx+1:]
+
+        assert.Equal(t, "[INFO ] test: this is test: who=programmer why=testing\n", rest)
+    })
+
+    t.Run("quotes values with spaces", func(t *testing.T) {
+        var buf bytes.Buffer
+
+        logger := New(&LoggerOptions{
+            Name: "test",
+            Output: &buf,
+        })
+
+        logger.Info("this is test", "who", "programmer", "why", "testing is fun")
+
+        str := buf.String()
+
+        dataIdx := strings.IndexByte(str, ' ')
+
+        // ts := str[:dataIdx]
+        rest := str[dataIdx+1:]
+
+        assert.Equal(t, "[INFO ] test: this is test: who=programmer why=\"testing is fun\"\n", rest)
+    })
+
+    t.Run("outputs stack traces", func(t *testing.T) {
+        var buf bytes.Buffer
+
+        logger := New(&LoggerOptions{
+            Name: "test",
+            Output: &buf,
+        })
+
+        logger.Info("who", "programmer", "why", "testing", Stacktrace())
+
+        lines := strings.Split(buf.String(), "\n") // line 0 is the entry, 1+ the trace
+
+        require.True(t, len(lines) > 1)
+
+        assert.Equal(t, "github.com/hashicorp/go-hclog.Stacktrace", lines[1])
+    })
+
+    t.Run("outputs stack traces with it's given a name", func(t *testing.T) {
+        var buf bytes.Buffer
+
+        logger := New(&LoggerOptions{
+            Name: "test",
+            Output: &buf,
+        })
+
+        logger.Info("who", "programmer", "why", "testing", "foo", Stacktrace())
+
+        lines := strings.Split(buf.String(), "\n")
+
+        require.True(t, len(lines) > 1)
+
+        assert.Equal(t, "github.com/hashicorp/go-hclog.Stacktrace", lines[1])
+    })
+
+    t.Run("includes the caller location", func(t *testing.T) {
+        var buf bytes.Buffer
+
+        logger := New(&LoggerOptions{
+            Name: "test",
+            Output: &buf,
+            IncludeLocation: true,
+        })
+
+        logger.Info("this is test", "who", "programmer", "why", "testing is fun")
+
+        str := buf.String()
+
+        dataIdx := strings.IndexByte(str, ' ')
+
+        // ts := str[:dataIdx]
+        rest := str[dataIdx+1:]
+
+        // This test will break if you move this around, it is line dependent
+        assert.Equal(t, "[INFO ] go-hclog/logger_test.go:101: test: this is test: who=programmer why=\"testing is fun\"\n", rest)
+    })
+
+    t.Run("prefixes the name", func(t *testing.T) {
+        var buf bytes.Buffer
+
+        logger := New(&LoggerOptions{
+            // No name!
+            Output: &buf,
+        })
+
+        logger.Info("this is test")
+        str := buf.String()
+        dataIdx := strings.IndexByte(str, ' ')
+        rest := str[dataIdx+1:]
+        assert.Equal(t, "[INFO ] this is test\n", rest)
+
+        buf.Reset()
+
+        // Named on a nameless logger must use the new name outright.
+        another := logger.Named("sublogger")
+        another.Info("this is test")
+        str = buf.String()
+        dataIdx = strings.IndexByte(str, ' ')
+        rest = str[dataIdx+1:]
+        assert.Equal(t, "[INFO ] sublogger: this is test\n", rest)
+    })
+
+    t.Run("use a different time format", func(t *testing.T) {
+        var buf bytes.Buffer
+
+        logger := New(&LoggerOptions{
+            Name: "test",
+            Output: &buf,
+            TimeFormat: time.Kitchen,
+        })
+
+        logger.Info("this is test", "who", "programmer", "why", "testing is fun")
+
+        str := buf.String()
+
+        dataIdx := strings.IndexByte(str, ' ')
+
+        assert.Equal(t, str[:dataIdx], time.Now().Format(time.Kitchen))
+    })
+
+    t.Run("use with", func(t *testing.T) {
+        var buf bytes.Buffer
+
+        rootLogger := New(&LoggerOptions{
+            Name: "with_test",
+            Output: &buf,
+        })
+
+        // Build the root logger in two steps, which triggers a slice capacity increase
+        // and is part of the test for inadvertant slice aliasing.
+        rootLogger = rootLogger.With("a", 1, "b", 2)
+        rootLogger = rootLogger.With("c", 3)
+
+        // Derive two new loggers which should be completely independent
+        derived1 := rootLogger.With("cat", 30)
+        derived2 := rootLogger.With("dog", 40)
+
+        derived1.Info("test1")
+        output := buf.String()
+        dataIdx := strings.IndexByte(output, ' ')
+        assert.Equal(t, "[INFO ] with_test: test1: a=1 b=2 c=3 cat=30\n", output[dataIdx+1:])
+
+        buf.Reset()
+
+        derived2.Info("test2")
+        output = buf.String()
+        dataIdx = strings.IndexByte(output, ' ')
+        assert.Equal(t, "[INFO ] with_test: test2: a=1 b=2 c=3 dog=40\n", output[dataIdx+1:])
+    })
+
+    t.Run("unpaired with", func(t *testing.T) {
+        // With with an odd number of arguments must panic.
+        defer func() {
+            if r := recover(); r == nil {
+                t.Fatal("expected panic")
+            }
+        }()
+
+        var buf bytes.Buffer
+
+        rootLogger := New(&LoggerOptions{
+            Name: "with_test",
+            Output: &buf,
+        })
+
+        rootLogger = rootLogger.With("a")
+    })
+
+    t.Run("use with and log", func(t *testing.T) {
+        var buf bytes.Buffer
+
+        rootLogger := New(&LoggerOptions{
+            Name: "with_test",
+            Output: &buf,
+        })
+
+        // Build the root logger in two steps, which triggers a slice capacity increase
+        // and is part of the test for inadvertant slice aliasing.
+        rootLogger = rootLogger.With("a", 1, "b", 2)
+        // This line is here to test that when calling With with the same key,
+        // only the last value remains (see issue #21)
+        rootLogger = rootLogger.With("c", 4)
+        rootLogger = rootLogger.With("c", 3)
+
+        // Derive another logger which should be completely independent of rootLogger
+        derived := rootLogger.With("cat", 30)
+
+        rootLogger.Info("root_test", "bird", 10)
+        output := buf.String()
+        dataIdx := strings.IndexByte(output, ' ')
+        assert.Equal(t, "[INFO ] with_test: root_test: a=1 b=2 c=3 bird=10\n", output[dataIdx+1:])
+
+        buf.Reset()
+
+        derived.Info("derived_test")
+        output = buf.String()
+        dataIdx = strings.IndexByte(output, ' ')
+        assert.Equal(t, "[INFO ] with_test: derived_test: a=1 b=2 c=3 cat=30\n", output[dataIdx+1:])
+    })
+
+    t.Run("use with and log and change levels", func(t *testing.T) {
+        var buf bytes.Buffer
+
+        rootLogger := New(&LoggerOptions{
+            Name: "with_test",
+            Output: &buf,
+            Level: Warn,
+        })
+
+        // Build the root logger in two steps, which triggers a slice capacity increase
+        // and is part of the test for inadvertant slice aliasing.
+        rootLogger = rootLogger.With("a", 1, "b", 2)
+        rootLogger = rootLogger.With("c", 3)
+
+        // Derive another logger which should be completely independent of rootLogger
+        derived := rootLogger.With("cat", 30)
+
+        // At Warn, Info output from both loggers must be suppressed.
+        rootLogger.Info("root_test", "bird", 10)
+        output := buf.String()
+        if output != "" {
+            t.Fatalf("unexpected output: %s", output)
+        }
+
+        buf.Reset()
+
+        derived.Info("derived_test")
+        output = buf.String()
+        if output != "" {
+            t.Fatalf("unexpected output: %s", output)
+        }
+
+        // SetLevel on the derived logger reaches the shared level word,
+        // so the root logger starts emitting too.
+        derived.SetLevel(Info)
+
+        rootLogger.Info("root_test", "bird", 10)
+        output = buf.String()
+        dataIdx := strings.IndexByte(output, ' ')
+        assert.Equal(t, "[INFO ] with_test: root_test: a=1 b=2 c=3 bird=10\n", output[dataIdx+1:])
+
+        buf.Reset()
+
+        derived.Info("derived_test")
+        output = buf.String()
+        dataIdx = strings.IndexByte(output, ' ')
+        assert.Equal(t, "[INFO ] with_test: derived_test: a=1 b=2 c=3 cat=30\n", output[dataIdx+1:])
+    })
+
+    t.Run("supports Printf style expansions when requested", func(t *testing.T) {
+        var buf bytes.Buffer
+
+        logger := New(&LoggerOptions{
+            Name: "test",
+            Output: &buf,
+        })
+
+        logger.Info("this is test", "production", Fmt("%d beans/day", 12))
+
+        str := buf.String()
+
+        dataIdx := strings.IndexByte(str, ' ')
+
+        // ts := str[:dataIdx]
+        rest := str[dataIdx+1:]
+
+        assert.Equal(t, "[INFO ] test: this is test: production=\"12 beans/day\"\n", rest)
+    })
+}
+
+func TestLogger_JSON(t *testing.T) {
+ t.Run("json formatting", func(t *testing.T) {
+ var buf bytes.Buffer
+ logger := New(&LoggerOptions{
+ Name: "test",
+ Output: &buf,
+ JSONFormat: true,
+ })
+
+ logger.Info("this is test", "who", "programmer", "why", "testing is fun")
+
+ b := buf.Bytes()
+
+ var raw map[string]interface{}
+ if err := json.Unmarshal(b, &raw); err != nil {
+ t.Fatal(err)
+ }
+
+ assert.Equal(t, "this is test", raw["@message"])
+ assert.Equal(t, "programmer", raw["who"])
+ assert.Equal(t, "testing is fun", raw["why"])
+ })
+
+ t.Run("json formatting with", func(t *testing.T) {
+ var buf bytes.Buffer
+ logger := New(&LoggerOptions{
+ Name: "test",
+ Output: &buf,
+ JSONFormat: true,
+ })
+ logger = logger.With("cat", "in the hat", "dog", 42)
+
+ logger.Info("this is test", "who", "programmer", "why", "testing is fun")
+
+ b := buf.Bytes()
+
+ var raw map[string]interface{}
+ if err := json.Unmarshal(b, &raw); err != nil {
+ t.Fatal(err)
+ }
+
+ assert.Equal(t, "this is test", raw["@message"])
+ assert.Equal(t, "programmer", raw["who"])
+ assert.Equal(t, "testing is fun", raw["why"])
+ assert.Equal(t, "in the hat", raw["cat"])
+ assert.Equal(t, float64(42), raw["dog"])
+ })
+
+ t.Run("json formatting error type", func(t *testing.T) {
+ var buf bytes.Buffer
+
+ logger := New(&LoggerOptions{
+ Name: "test",
+ Output: &buf,
+ JSONFormat: true,
+ })
+
+ errMsg := errors.New("this is an error")
+ logger.Info("this is test", "who", "programmer", "err", errMsg)
+
+ b := buf.Bytes()
+
+ var raw map[string]interface{}
+ if err := json.Unmarshal(b, &raw); err != nil {
+ t.Fatal(err)
+ }
+
+ assert.Equal(t, "this is test", raw["@message"])
+ assert.Equal(t, "programmer", raw["who"])
+ assert.Equal(t, errMsg.Error(), raw["err"])
+ })
+
+ t.Run("json formatting custom error type json marshaler", func(t *testing.T) {
+ var buf bytes.Buffer
+
+ logger := New(&LoggerOptions{
+ Name: "test",
+ Output: &buf,
+ JSONFormat: true,
+ })
+
+ errMsg := &customErrJSON{"this is an error"}
+ rawMsg, err := errMsg.MarshalJSON()
+ if err != nil {
+ t.Fatal(err)
+ }
+ expectedMsg, err := strconv.Unquote(string(rawMsg))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ logger.Info("this is test", "who", "programmer", "err", errMsg)
+
+ b := buf.Bytes()
+
+ var raw map[string]interface{}
+ if err := json.Unmarshal(b, &raw); err != nil {
+ t.Fatal(err)
+ }
+
+ assert.Equal(t, "this is test", raw["@message"])
+ assert.Equal(t, "programmer", raw["who"])
+ assert.Equal(t, expectedMsg, raw["err"])
+ })
+
	// An error that also implements encoding.TextMarshaler should be
	// encoded with the marshaler's text output rather than Error().
	t.Run("json formatting custom error type text marshaler", func(t *testing.T) {
		var buf bytes.Buffer

		logger := New(&LoggerOptions{
			Name:       "test",
			Output:     &buf,
			JSONFormat: true,
		})

		errMsg := &customErrText{"this is an error"}
		// The expected field value is exactly the marshaler's text output.
		rawMsg, err := errMsg.MarshalText()
		if err != nil {
			t.Fatal(err)
		}
		expectedMsg := string(rawMsg)

		logger.Info("this is test", "who", "programmer", "err", errMsg)

		b := buf.Bytes()

		var raw map[string]interface{}
		if err := json.Unmarshal(b, &raw); err != nil {
			t.Fatal(err)
		}

		assert.Equal(t, "this is test", raw["@message"])
		assert.Equal(t, "programmer", raw["who"])
		assert.Equal(t, expectedMsg, raw["err"])
	})
+
	// A value wrapped with hclog.Fmt() should be expanded Printf-style
	// before being emitted as the pair's value.
	t.Run("supports Printf style expansions when requested", func(t *testing.T) {
		var buf bytes.Buffer

		logger := New(&LoggerOptions{
			Name:       "test",
			Output:     &buf,
			JSONFormat: true,
		})

		logger.Info("this is test", "production", Fmt("%d beans/day", 12))

		b := buf.Bytes()

		var raw map[string]interface{}
		if err := json.Unmarshal(b, &raw); err != nil {
			t.Fatal(err)
		}

		assert.Equal(t, "this is test", raw["@message"])
		assert.Equal(t, "12 beans/day", raw["production"])
	})
+}
+
// customErrJSON is a test error whose JSON encoding is customized via
// json.Marshaler; the "json-marshaler:" prefix lets tests verify that the
// custom path, not Error(), produced the output.
type customErrJSON struct {
	Message string
}

// Error implements the error interface.
func (c *customErrJSON) Error() string { return c.Message }

// MarshalJSON implements json.Marshaler, emitting the message as a quoted
// JSON string carrying a distinguishing prefix.
func (c customErrJSON) MarshalJSON() ([]byte, error) {
	encoded := strconv.Quote(fmt.Sprintf("json-marshaler: %s", c.Message))
	return []byte(encoded), nil
}
+
// customErrText is a test error whose text encoding is customized via
// encoding.TextMarshaler; the "text-marshaler:" prefix lets tests verify
// that the custom path, not Error(), produced the output.
type customErrText struct {
	Message string
}

// Error implements the error interface.
func (c *customErrText) Error() string { return c.Message }

// MarshalText implements encoding.TextMarshaler.
func (c customErrText) MarshalText() ([]byte, error) {
	encoded := fmt.Sprintf("text-marshaler: %s", c.Message)
	return []byte(encoded), nil
}
+
// BenchmarkLogger measures the cost of a single Info call carrying ten
// key/value pairs, with caller-location lookup enabled.
func BenchmarkLogger(b *testing.B) {
	b.Run("info with 10 pairs", func(b *testing.B) {
		var buf bytes.Buffer

		logger := New(&LoggerOptions{
			Name:   "test",
			Output: &buf,
			// Location lookup is part of the measured cost.
			IncludeLocation: true,
		})

		// NOTE(review): buf is never reset between iterations, so buffer
		// growth/reallocation is folded into the reported numbers.
		for i := 0; i < b.N; i++ {
			logger.Info("this is some message",
				"name", "foo",
				"what", "benchmarking yourself",
				"why", "to see what's slow",
				"k4", "value",
				"k5", "value",
				"k6", "value",
				"k7", "value",
				"k8", "value",
				"k9", "value",
				"k10", "value",
			)
		}
	})
}
--- /dev/null
+package hclog
+
+import (
+ "io/ioutil"
+ "log"
+)
+
// NewNullLogger instantiates a Logger for which all calls
// will succeed without doing anything.
// Useful for testing purposes.
// Each call returns a fresh (stateless) instance.
func NewNullLogger() Logger {
	return &nullLogger{}
}
+
// nullLogger is the stateless no-op Logger implementation behind
// NewNullLogger; a single value can safely be shared everywhere.
type nullLogger struct{}

// The emit methods discard their arguments entirely.

func (l *nullLogger) Trace(msg string, args ...interface{}) {}

func (l *nullLogger) Debug(msg string, args ...interface{}) {}

func (l *nullLogger) Info(msg string, args ...interface{}) {}

func (l *nullLogger) Warn(msg string, args ...interface{}) {}

func (l *nullLogger) Error(msg string, args ...interface{}) {}

// Every level reports disabled so callers gating on Is* can skip the work
// of building log arguments.

func (l *nullLogger) IsTrace() bool { return false }

func (l *nullLogger) IsDebug() bool { return false }

func (l *nullLogger) IsInfo() bool { return false }

func (l *nullLogger) IsWarn() bool { return false }

func (l *nullLogger) IsError() bool { return false }

// Derived loggers are the receiver itself; there is no state to copy.

func (l *nullLogger) With(args ...interface{}) Logger { return l }

func (l *nullLogger) Named(name string) Logger { return l }

func (l *nullLogger) ResetNamed(name string) Logger { return l }

// SetLevel is a no-op; the null logger has no level to change.
func (l *nullLogger) SetLevel(level Level) {}

// StandardLogger returns a real (non-nil) *log.Logger that discards all
// writes, so callers can use it without nil checks.
func (l *nullLogger) StandardLogger(opts *StandardLoggerOptions) *log.Logger {
	return log.New(ioutil.Discard, "", log.LstdFlags)
}
--- /dev/null
+package hclog
+
+import (
+ "testing"
+ "github.com/stretchr/testify/assert"
+)
+
// logger is the shared null logger exercised by every test in this file.
var logger = NewNullLogger()
+
+func TestNullLoggerIsEfficient(t *testing.T) {
+ // Since statements like "IsWarn()", "IsError()", etc. are used to gate
+ // actually writing warning and error statements, the null logger will
+ // be faster and more efficient if it always returns false for these calls.
+ assert.False(t, logger.IsTrace())
+ assert.False(t, logger.IsDebug())
+ assert.False(t, logger.IsInfo())
+ assert.False(t, logger.IsWarn())
+ assert.False(t, logger.IsError())
+}
+
+func TestNullLoggerReturnsNullLoggers(t *testing.T) {
+
+ // Sometimes the logger is asked to return subloggers.
+ // These should also be a nullLogger.
+
+ subLogger := logger.With()
+ _, ok := subLogger.(*nullLogger)
+ assert.True(t, ok)
+
+ subLogger = logger.Named("")
+ _, ok = subLogger.(*nullLogger)
+ assert.True(t, ok)
+
+ subLogger = logger.ResetNamed("")
+ _, ok = subLogger.(*nullLogger)
+ assert.True(t, ok)
+}
+
+func TestStandardLoggerIsntNil(t *testing.T) {
+ // Don't return a nil pointer for the standard logger,
+ // lest it cause a panic.
+ stdLogger := logger.StandardLogger(nil)
+ assert.NotEqual(t, nil, stdLogger)
+}
--- /dev/null
+// Copyright (c) 2016 Uber Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+package hclog
+
+import (
+ "bytes"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+)
+
var (
	// Function-name prefixes that are pure noise in a captured trace:
	// the runtime frames below main/goroutine entry points.
	_stacktraceIgnorePrefixes = []string{
		"runtime.goexit",
		"runtime.main",
	}
	// Pool of reusable program-counter buffers. New buffers start at 64
	// entries; takeStacktrace grows them on demand for deeper stacks.
	_stacktracePool = sync.Pool{
		New: func() interface{} {
			return newProgramCounters(64)
		},
	}
)
+
// A CapturedStacktrace is a stacktrace gathered by a previous call to
// Stacktrace(). If passed to a logging function, the stacktrace will be
// appended.
type CapturedStacktrace string

// Stacktrace gathers a stacktrace of the current goroutine and returns it
// to be passed to a logging function.
func Stacktrace() CapturedStacktrace {
	return CapturedStacktrace(takeStacktrace())
}
+
+func takeStacktrace() string {
+ programCounters := _stacktracePool.Get().(*programCounters)
+ defer _stacktracePool.Put(programCounters)
+
+ var buffer bytes.Buffer
+
+ for {
+ // Skip the call to runtime.Counters and takeStacktrace so that the
+ // program counters start at the caller of takeStacktrace.
+ n := runtime.Callers(2, programCounters.pcs)
+ if n < cap(programCounters.pcs) {
+ programCounters.pcs = programCounters.pcs[:n]
+ break
+ }
+ // Don't put the too-short counter slice back into the pool; this lets
+ // the pool adjust if we consistently take deep stacktraces.
+ programCounters = newProgramCounters(len(programCounters.pcs) * 2)
+ }
+
+ i := 0
+ frames := runtime.CallersFrames(programCounters.pcs)
+ for frame, more := frames.Next(); more; frame, more = frames.Next() {
+ if shouldIgnoreStacktraceFunction(frame.Function) {
+ continue
+ }
+ if i != 0 {
+ buffer.WriteByte('\n')
+ }
+ i++
+ buffer.WriteString(frame.Function)
+ buffer.WriteByte('\n')
+ buffer.WriteByte('\t')
+ buffer.WriteString(frame.File)
+ buffer.WriteByte(':')
+ buffer.WriteString(strconv.Itoa(int(frame.Line)))
+ }
+
+ return buffer.String()
+}
+
+func shouldIgnoreStacktraceFunction(function string) bool {
+ for _, prefix := range _stacktraceIgnorePrefixes {
+ if strings.HasPrefix(function, prefix) {
+ return true
+ }
+ }
+ return false
+}
+
// programCounters is a reusable buffer of raw program counters, pooled in
// _stacktracePool to amortize allocation across stacktrace captures.
type programCounters struct {
	pcs []uintptr
}

// newProgramCounters allocates a buffer able to hold size counters.
func newProgramCounters(size int) *programCounters {
	return &programCounters{pcs: make([]uintptr, size)}
}
--- /dev/null
+package hclog
+
+import (
+ "bytes"
+ "strings"
+)
+
// Provides a io.Writer to shim the data out of *log.Logger
// and back into our Logger. This is basically the only way to
// build upon *log.Logger.
type stdlogAdapter struct {
	hl          Logger // destination logger every line is re-emitted through
	inferLevels bool   // when set, parse a leading "[LEVEL]" tag off each line
}
+
+// Take the data, infer the levels if configured, and send it through
+// a regular Logger
+func (s *stdlogAdapter) Write(data []byte) (int, error) {
+ str := string(bytes.TrimRight(data, " \t\n"))
+
+ if s.inferLevels {
+ level, str := s.pickLevel(str)
+ switch level {
+ case Trace:
+ s.hl.Trace(str)
+ case Debug:
+ s.hl.Debug(str)
+ case Info:
+ s.hl.Info(str)
+ case Warn:
+ s.hl.Warn(str)
+ case Error:
+ s.hl.Error(str)
+ default:
+ s.hl.Info(str)
+ }
+ } else {
+ s.hl.Info(str)
+ }
+
+ return len(data), nil
+}
+
+// Detect, based on conventions, what log level this is
+func (s *stdlogAdapter) pickLevel(str string) (Level, string) {
+ switch {
+ case strings.HasPrefix(str, "[DEBUG]"):
+ return Debug, strings.TrimSpace(str[7:])
+ case strings.HasPrefix(str, "[TRACE]"):
+ return Trace, strings.TrimSpace(str[7:])
+ case strings.HasPrefix(str, "[INFO]"):
+ return Info, strings.TrimSpace(str[6:])
+ case strings.HasPrefix(str, "[WARN]"):
+ return Warn, strings.TrimSpace(str[7:])
+ case strings.HasPrefix(str, "[ERROR]"):
+ return Error, strings.TrimSpace(str[7:])
+ case strings.HasPrefix(str, "[ERR]"):
+ return Error, strings.TrimSpace(str[5:])
+ default:
+ return Info, str
+ }
+}
--- /dev/null
+package hclog
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestStdlogAdapter(t *testing.T) {
+ t.Run("picks debug level", func(t *testing.T) {
+ var s stdlogAdapter
+
+ level, rest := s.pickLevel("[DEBUG] coffee?")
+
+ assert.Equal(t, Debug, level)
+ assert.Equal(t, "coffee?", rest)
+ })
+
+ t.Run("picks trace level", func(t *testing.T) {
+ var s stdlogAdapter
+
+ level, rest := s.pickLevel("[TRACE] coffee?")
+
+ assert.Equal(t, Trace, level)
+ assert.Equal(t, "coffee?", rest)
+ })
+
+ t.Run("picks info level", func(t *testing.T) {
+ var s stdlogAdapter
+
+ level, rest := s.pickLevel("[INFO] coffee?")
+
+ assert.Equal(t, Info, level)
+ assert.Equal(t, "coffee?", rest)
+ })
+
+ t.Run("picks warn level", func(t *testing.T) {
+ var s stdlogAdapter
+
+ level, rest := s.pickLevel("[WARN] coffee?")
+
+ assert.Equal(t, Warn, level)
+ assert.Equal(t, "coffee?", rest)
+ })
+
+ t.Run("picks error level", func(t *testing.T) {
+ var s stdlogAdapter
+
+ level, rest := s.pickLevel("[ERROR] coffee?")
+
+ assert.Equal(t, Error, level)
+ assert.Equal(t, "coffee?", rest)
+ })
+
+ t.Run("picks error as err level", func(t *testing.T) {
+ var s stdlogAdapter
+
+ level, rest := s.pickLevel("[ERR] coffee?")
+
+ assert.Equal(t, Error, level)
+ assert.Equal(t, "coffee?", rest)
+ })
+}
--- /dev/null
+Mozilla Public License, version 2.0
+
+1. Definitions
+
+1.1. “Contributor”
+
+ means each individual or legal entity that creates, contributes to the
+ creation of, or owns Covered Software.
+
+1.2. “Contributor Version”
+
+ means the combination of the Contributions of others (if any) used by a
+ Contributor and that particular Contributor’s Contribution.
+
+1.3. “Contribution”
+
+ means Covered Software of a particular Contributor.
+
+1.4. “Covered Software”
+
+ means Source Code Form to which the initial Contributor has attached the
+ notice in Exhibit A, the Executable Form of such Source Code Form, and
+ Modifications of such Source Code Form, in each case including portions
+ thereof.
+
+1.5. “Incompatible With Secondary Licenses”
+ means
+
+ a. that the initial Contributor has attached the notice described in
+ Exhibit B to the Covered Software; or
+
+ b. that the Covered Software was made available under the terms of version
+ 1.1 or earlier of the License, but not also under the terms of a
+ Secondary License.
+
+1.6. “Executable Form”
+
+ means any form of the work other than Source Code Form.
+
+1.7. “Larger Work”
+
+ means a work that combines Covered Software with other material, in a separate
+ file or files, that is not Covered Software.
+
+1.8. “License”
+
+ means this document.
+
+1.9. “Licensable”
+
+ means having the right to grant, to the maximum extent possible, whether at the
+ time of the initial grant or subsequently, any and all of the rights conveyed by
+ this License.
+
+1.10. “Modifications”
+
+ means any of the following:
+
+ a. any file in Source Code Form that results from an addition to, deletion
+ from, or modification of the contents of Covered Software; or
+
+ b. any new file in Source Code Form that contains any Covered Software.
+
+1.11. “Patent Claims” of a Contributor
+
+ means any patent claim(s), including without limitation, method, process,
+ and apparatus claims, in any patent Licensable by such Contributor that
+ would be infringed, but for the grant of the License, by the making,
+ using, selling, offering for sale, having made, import, or transfer of
+ either its Contributions or its Contributor Version.
+
+1.12. “Secondary License”
+
+ means either the GNU General Public License, Version 2.0, the GNU Lesser
+ General Public License, Version 2.1, the GNU Affero General Public
+ License, Version 3.0, or any later versions of those licenses.
+
+1.13. “Source Code Form”
+
+ means the form of the work preferred for making modifications.
+
+1.14. “You” (or “Your”)
+
+ means an individual or a legal entity exercising rights under this
+ License. For legal entities, “You” includes any entity that controls, is
+ controlled by, or is under common control with You. For purposes of this
+ definition, “control” means (a) the power, direct or indirect, to cause
+ the direction or management of such entity, whether by contract or
+ otherwise, or (b) ownership of more than fifty percent (50%) of the
+ outstanding shares or beneficial ownership of such entity.
+
+
+2. License Grants and Conditions
+
+2.1. Grants
+
+ Each Contributor hereby grants You a world-wide, royalty-free,
+ non-exclusive license:
+
+ a. under intellectual property rights (other than patent or trademark)
+ Licensable by such Contributor to use, reproduce, make available,
+ modify, display, perform, distribute, and otherwise exploit its
+ Contributions, either on an unmodified basis, with Modifications, or as
+ part of a Larger Work; and
+
+ b. under Patent Claims of such Contributor to make, use, sell, offer for
+ sale, have made, import, and otherwise transfer either its Contributions
+ or its Contributor Version.
+
+2.2. Effective Date
+
+ The licenses granted in Section 2.1 with respect to any Contribution become
+ effective for each Contribution on the date the Contributor first distributes
+ such Contribution.
+
+2.3. Limitations on Grant Scope
+
+ The licenses granted in this Section 2 are the only rights granted under this
+ License. No additional rights or licenses will be implied from the distribution
+ or licensing of Covered Software under this License. Notwithstanding Section
+ 2.1(b) above, no patent license is granted by a Contributor:
+
+ a. for any code that a Contributor has removed from Covered Software; or
+
+ b. for infringements caused by: (i) Your and any other third party’s
+ modifications of Covered Software, or (ii) the combination of its
+ Contributions with other software (except as part of its Contributor
+ Version); or
+
+ c. under Patent Claims infringed by Covered Software in the absence of its
+ Contributions.
+
+ This License does not grant any rights in the trademarks, service marks, or
+ logos of any Contributor (except as may be necessary to comply with the
+ notice requirements in Section 3.4).
+
+2.4. Subsequent Licenses
+
+ No Contributor makes additional grants as a result of Your choice to
+ distribute the Covered Software under a subsequent version of this License
+ (see Section 10.2) or under the terms of a Secondary License (if permitted
+ under the terms of Section 3.3).
+
+2.5. Representation
+
+ Each Contributor represents that the Contributor believes its Contributions
+ are its original creation(s) or it has sufficient rights to grant the
+ rights to its Contributions conveyed by this License.
+
+2.6. Fair Use
+
+ This License is not intended to limit any rights You have under applicable
+ copyright doctrines of fair use, fair dealing, or other equivalents.
+
+2.7. Conditions
+
+ Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in
+ Section 2.1.
+
+
+3. Responsibilities
+
+3.1. Distribution of Source Form
+
+ All distribution of Covered Software in Source Code Form, including any
+ Modifications that You create or to which You contribute, must be under the
+ terms of this License. You must inform recipients that the Source Code Form
+ of the Covered Software is governed by the terms of this License, and how
+ they can obtain a copy of this License. You may not attempt to alter or
+ restrict the recipients’ rights in the Source Code Form.
+
+3.2. Distribution of Executable Form
+
+ If You distribute Covered Software in Executable Form then:
+
+ a. such Covered Software must also be made available in Source Code Form,
+ as described in Section 3.1, and You must inform recipients of the
+ Executable Form how they can obtain a copy of such Source Code Form by
+ reasonable means in a timely manner, at a charge no more than the cost
+ of distribution to the recipient; and
+
+ b. You may distribute such Executable Form under the terms of this License,
+ or sublicense it under different terms, provided that the license for
+ the Executable Form does not attempt to limit or alter the recipients’
+ rights in the Source Code Form under this License.
+
+3.3. Distribution of a Larger Work
+
+ You may create and distribute a Larger Work under terms of Your choice,
+ provided that You also comply with the requirements of this License for the
+ Covered Software. If the Larger Work is a combination of Covered Software
+ with a work governed by one or more Secondary Licenses, and the Covered
+ Software is not Incompatible With Secondary Licenses, this License permits
+ You to additionally distribute such Covered Software under the terms of
+ such Secondary License(s), so that the recipient of the Larger Work may, at
+ their option, further distribute the Covered Software under the terms of
+ either this License or such Secondary License(s).
+
+3.4. Notices
+
+ You may not remove or alter the substance of any license notices (including
+ copyright notices, patent notices, disclaimers of warranty, or limitations
+ of liability) contained within the Source Code Form of the Covered
+ Software, except that You may alter any license notices to the extent
+ required to remedy known factual inaccuracies.
+
+3.5. Application of Additional Terms
+
+ You may choose to offer, and to charge a fee for, warranty, support,
+ indemnity or liability obligations to one or more recipients of Covered
+ Software. However, You may do so only on Your own behalf, and not on behalf
+ of any Contributor. You must make it absolutely clear that any such
+ warranty, support, indemnity, or liability obligation is offered by You
+ alone, and You hereby agree to indemnify every Contributor for any
+ liability incurred by such Contributor as a result of warranty, support,
+ indemnity or liability terms You offer. You may include additional
+ disclaimers of warranty and limitations of liability specific to any
+ jurisdiction.
+
+4. Inability to Comply Due to Statute or Regulation
+
+ If it is impossible for You to comply with any of the terms of this License
+ with respect to some or all of the Covered Software due to statute, judicial
+ order, or regulation then You must: (a) comply with the terms of this License
+ to the maximum extent possible; and (b) describe the limitations and the code
+ they affect. Such description must be placed in a text file included with all
+ distributions of the Covered Software under this License. Except to the
+ extent prohibited by statute or regulation, such description must be
+ sufficiently detailed for a recipient of ordinary skill to be able to
+ understand it.
+
+5. Termination
+
+5.1. The rights granted under this License will terminate automatically if You
+ fail to comply with any of its terms. However, if You become compliant,
+ then the rights granted under this License from a particular Contributor
+ are reinstated (a) provisionally, unless and until such Contributor
+ explicitly and finally terminates Your grants, and (b) on an ongoing basis,
+ if such Contributor fails to notify You of the non-compliance by some
+ reasonable means prior to 60 days after You have come back into compliance.
+ Moreover, Your grants from a particular Contributor are reinstated on an
+ ongoing basis if such Contributor notifies You of the non-compliance by
+ some reasonable means, this is the first time You have received notice of
+ non-compliance with this License from such Contributor, and You become
+ compliant prior to 30 days after Your receipt of the notice.
+
+5.2. If You initiate litigation against any entity by asserting a patent
+ infringement claim (excluding declaratory judgment actions, counter-claims,
+ and cross-claims) alleging that a Contributor Version directly or
+ indirectly infringes any patent, then the rights granted to You by any and
+ all Contributors for the Covered Software under Section 2.1 of this License
+ shall terminate.
+
+5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user
+ license agreements (excluding distributors and resellers) which have been
+ validly granted by You or Your distributors under this License prior to
+ termination shall survive termination.
+
+6. Disclaimer of Warranty
+
+ Covered Software is provided under this License on an “as is” basis, without
+ warranty of any kind, either expressed, implied, or statutory, including,
+ without limitation, warranties that the Covered Software is free of defects,
+ merchantable, fit for a particular purpose or non-infringing. The entire
+ risk as to the quality and performance of the Covered Software is with You.
+ Should any Covered Software prove defective in any respect, You (not any
+ Contributor) assume the cost of any necessary servicing, repair, or
+ correction. This disclaimer of warranty constitutes an essential part of this
+ License. No use of any Covered Software is authorized under this License
+ except under this disclaimer.
+
+7. Limitation of Liability
+
+ Under no circumstances and under no legal theory, whether tort (including
+ negligence), contract, or otherwise, shall any Contributor, or anyone who
+ distributes Covered Software as permitted above, be liable to You for any
+ direct, indirect, special, incidental, or consequential damages of any
+ character including, without limitation, damages for lost profits, loss of
+ goodwill, work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses, even if such party shall have been
+ informed of the possibility of such damages. This limitation of liability
+ shall not apply to liability for death or personal injury resulting from such
+ party’s negligence to the extent applicable law prohibits such limitation.
+ Some jurisdictions do not allow the exclusion or limitation of incidental or
+ consequential damages, so this exclusion and limitation may not apply to You.
+
+8. Litigation
+
+ Any litigation relating to this License may be brought only in the courts of
+ a jurisdiction where the defendant maintains its principal place of business
+ and such litigation shall be governed by laws of that jurisdiction, without
+ reference to its conflict-of-law provisions. Nothing in this Section shall
+ prevent a party’s ability to bring cross-claims or counter-claims.
+
+9. Miscellaneous
+
+ This License represents the complete agreement concerning the subject matter
+ hereof. If any provision of this License is held to be unenforceable, such
+ provision shall be reformed only to the extent necessary to make it
+ enforceable. Any law or regulation which provides that the language of a
+ contract shall be construed against the drafter shall not be used to construe
+ this License against a Contributor.
+
+
+10. Versions of the License
+
+10.1. New Versions
+
+ Mozilla Foundation is the license steward. Except as provided in Section
+ 10.3, no one other than the license steward has the right to modify or
+ publish new versions of this License. Each version will be given a
+ distinguishing version number.
+
+10.2. Effect of New Versions
+
+ You may distribute the Covered Software under the terms of the version of
+ the License under which You originally received the Covered Software, or
+ under the terms of any subsequent version published by the license
+ steward.
+
+10.3. Modified Versions
+
+ If you create software not governed by this License, and you want to
+ create a new license for such software, you may create and use a modified
+ version of this License if you rename the license and remove any
+ references to the name of the license steward (except to note that such
+ modified license differs from this License).
+
+10.4. Distributing Source Code Form that is Incompatible With Secondary Licenses
+ If You choose to distribute Source Code Form that is Incompatible With
+ Secondary Licenses under the terms of this version of the License, the
+ notice described in Exhibit B of this License must be attached.
+
+Exhibit A - Source Code Form License Notice
+
+ This Source Code Form is subject to the
+ terms of the Mozilla Public License, v.
+ 2.0. If a copy of the MPL was not
+ distributed with this file, You can
+ obtain one at
+ http://mozilla.org/MPL/2.0/.
+
+If it is not possible or desirable to put the notice in a particular file, then
+You may include the notice in a location (such as a LICENSE file in a relevant
+directory) where a recipient would be likely to look for such a notice.
+
+You may add additional accurate notices of copyright ownership.
+
+Exhibit B - “Incompatible With Secondary Licenses” Notice
+
+ This Source Code Form is “Incompatible
+ With Secondary Licenses”, as defined by
+ the Mozilla Public License, v. 2.0.
--- /dev/null
+# Go Plugin System over RPC
+
+`go-plugin` is a Go (golang) plugin system over RPC. It is the plugin system
+that has been in use by HashiCorp tooling for over 4 years. While initially
+created for [Packer](https://www.packer.io), it is additionally in use by
+[Terraform](https://www.terraform.io), [Nomad](https://www.nomadproject.io), and
+[Vault](https://www.vaultproject.io).
+
+While the plugin system is over RPC, it is currently only designed to work
+over a local [reliable] network. Plugins over a real network are not supported
+and will lead to unexpected behavior.
+
+This plugin system has been used on millions of machines across many different
+projects and has proven to be battle hardened and ready for production use.
+
+## Features
+
+The HashiCorp plugin system supports a number of features:
+
+**Plugins are Go interface implementations.** This makes writing and consuming
+plugins feel very natural. To a plugin author: you just implement an
+interface as if it were going to run in the same process. For a plugin user:
+you just use and call functions on an interface as if it were in the same
+process. This plugin system handles the communication in between.
+
+**Cross-language support.** Plugins can be written (and consumed) by
+almost every major language. This library supports serving plugins via
+[gRPC](http://www.grpc.io). gRPC-based plugins enable plugins to be written
+in any language.
+
+**Complex arguments and return values are supported.** This library
+provides APIs for handling complex arguments and return values such
+as interfaces, `io.Reader/Writer`, etc. We do this by giving you a library
+(`MuxBroker`) for creating new connections between the client/server to
+serve additional interfaces or transfer raw data.
+
+**Bidirectional communication.** Because the plugin system supports
+complex arguments, the host process can send it interface implementations
+and the plugin can call back into the host process.
+
+**Built-in Logging.** Any plugins that use the `log` standard library
+will have log data automatically sent to the host process. The host
+process will mirror this output prefixed with the path to the plugin
+binary. This makes debugging with plugins simple. If the host system
+uses [hclog](https://github.com/hashicorp/go-hclog) then the log data
+will be structured. If the plugin also uses hclog, logs from the plugin
+will be sent to the host hclog and be structured.
+
+**Protocol Versioning.** A very basic "protocol version" is supported that
+can be incremented to invalidate any previous plugins. This is useful when
+interface signatures are changing, protocol level changes are necessary,
+etc. When a protocol version is incompatible, a human friendly error
+message is shown to the end user.
+
**Stdout/Stderr Syncing.** While plugins are subprocesses, they can continue
to use stdout/stderr as usual and the output will get mirrored back to
the host process. The host process can control what `io.Writer` these
streams go to, to prevent this from happening.
+
+**TTY Preservation.** Plugin subprocesses are connected to the identical
+stdin file descriptor as the host process, allowing software that requires
+a TTY to work. For example, a plugin can execute `ssh` and even though there
+are multiple subprocesses and RPC happening, it will look and act perfectly
+to the end user.
+
+**Host upgrade while a plugin is running.** Plugins can be "reattached"
+so that the host process can be upgraded while the plugin is still running.
+This requires the host/plugin to know this is possible and daemonize
+properly. `NewClient` takes a `ReattachConfig` to determine if and how to
+reattach.
+
+**Cryptographically Secure Plugins.** Plugins can be verified with an expected
+checksum and RPC communications can be configured to use TLS. The host process
+must be properly secured to protect this configuration.
+
+## Architecture
+
+The HashiCorp plugin system works by launching subprocesses and communicating
+over RPC (using standard `net/rpc` or [gRPC](http://www.grpc.io)). A single
+connection is made between any plugin and the host process. For net/rpc-based
+plugins, we use a [connection multiplexing](https://github.com/hashicorp/yamux)
+library to multiplex any other connections on top. For gRPC-based plugins,
+the HTTP2 protocol handles multiplexing.
+
+This architecture has a number of benefits:
+
+ * Plugins can't crash your host process: A panic in a plugin doesn't
+ panic the plugin user.
+
+ * Plugins are very easy to write: just write a Go application and `go build`.
+ Or use any other language to write a gRPC server with a tiny amount of
+ boilerplate to support go-plugin.
+
+ * Plugins are very easy to install: just put the binary in a location where
+ the host will find it (depends on the host but this library also provides
+ helpers), and the plugin host handles the rest.
+
+ * Plugins can be relatively secure: The plugin only has access to the
+ interfaces and args given to it, not to the entire memory space of the
+ process. Additionally, go-plugin can communicate with the plugin over
+ TLS.
+
+## Usage
+
+To use the plugin system, you must take the following steps. These are
+high-level steps that must be done. Examples are available in the
+`examples/` directory.
+
+ 1. Choose the interface(s) you want to expose for plugins.
+
 2. For each interface, implement an implementation of that interface
    that communicates over a `net/rpc` connection, over a
    [gRPC](http://www.grpc.io) connection, or both. You'll have to implement
    both a client and server implementation.
+
+ 3. Create a `Plugin` implementation that knows how to create the RPC
+ client/server for a given plugin type.
+
+ 4. Plugin authors call `plugin.Serve` to serve a plugin from the
+ `main` function.
+
+ 5. Plugin users use `plugin.Client` to launch a subprocess and request
+ an interface implementation over RPC.
+
+That's it! In practice, step 2 is the most tedious and time consuming step.
+Even so, it isn't very difficult and you can see examples in the `examples/`
+directory as well as throughout our various open source projects.
+
+For complete API documentation, see [GoDoc](https://godoc.org/github.com/hashicorp/go-plugin).
+
+## Roadmap
+
+Our plugin system is constantly evolving. As we use the plugin system for
+new projects or for new features in existing projects, we constantly find
+improvements we can make.
+
+At this point in time, the roadmap for the plugin system is:
+
+**Semantic Versioning.** Plugins will be able to implement a semantic version.
+This plugin system will give host processes a system for constraining
+versions. This is in addition to the protocol versioning already present
+which is more for larger underlying changes.
+
+**Plugin fetching.** We will integrate with [go-getter](https://github.com/hashicorp/go-getter)
+to support automatic download + install of plugins. Paired with cryptographically
secure plugins, we can make this a safe operation for an amazing
+user experience.
+
+## What About Shared Libraries?
+
+When we started using plugins (late 2012, early 2013), plugins over RPC
+were the only option since Go didn't support dynamic library loading. Today,
+Go still doesn't support dynamic library loading, but they do intend to.
+Since 2012, our plugin system has stabilized from millions of users using it,
+and has many benefits we've come to value greatly.
+
+For example, we intend to use this plugin system in
+[Vault](https://www.vaultproject.io), and dynamic library loading will
+simply never be acceptable in Vault for security reasons. That is an extreme
+example, but we believe our library system has more upsides than downsides
+over dynamic library loading and since we've had it built and tested for years,
+we'll likely continue to use it.
+
+Shared libraries have one major advantage over our system which is much
+higher performance. In real world scenarios across our various tools,
+we've never required any more performance out of our plugin system and it
+has seen very high throughput, so this isn't a concern for us at the moment.
+
--- /dev/null
+package plugin
+
+import (
+ "bufio"
+ "context"
+ "crypto/subtle"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "hash"
+ "io"
+ "io/ioutil"
+ "net"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+ "unicode"
+
+ hclog "github.com/hashicorp/go-hclog"
+)
+
// Killed is set to 1 once CleanupClients has been called. This can be
// used by plugin RPC implementations to change error behavior, since
// you can expect network connection errors at this point. This should
// be read using sync/atomic.
var Killed uint32 = 0

// managedClients is the slice of "managed" clients (created with
// ClientConfig.Managed set) that are killed when CleanupClients is
// called. Access is guarded by managedClientsLock.
var managedClients = make([]*Client, 0, 5)
var managedClientsLock sync.Mutex
+
// Error types
var (
	// ErrProcessNotFound is returned when a client is instantiated to
	// reattach to an existing process and it isn't found.
	ErrProcessNotFound = errors.New("Reattachment process not found")

	// ErrChecksumsDoNotMatch is returned when the binary's checksum
	// doesn't match the one provided in the SecureConfig.
	ErrChecksumsDoNotMatch = errors.New("checksums did not match")

	// ErrSecureConfigNoChecksum is returned when an empty checksum is
	// provided to the SecureConfig.
	ErrSecureConfigNoChecksum = errors.New("no checksum provided")

	// ErrSecureConfigNoHash is returned when a nil Hash object is
	// provided to the SecureConfig.
	ErrSecureConfigNoHash = errors.New("no hash implementation provided")

	// ErrSecureConfigAndReattach is returned when both Reattach and
	// SecureConfig are set.
	ErrSecureConfigAndReattach = errors.New("only one of Reattach or SecureConfig can be set")
)
+
// Client handles the lifecycle of a plugin application. It launches
// plugins, connects to them, dispenses interface implementations, and handles
// killing the process.
//
// Plugin hosts should use one Client for each plugin executable. To
// dispense a plugin type, use the `Client.Client` function, and then
// call `Dispense`. This awkward API is mostly historical but is used to split
// the client that deals with subprocess management and the client that
// does RPC management.
//
// See NewClient and ClientConfig for using a Client.
type Client struct {
	config      *ClientConfig   // immutable after NewClient
	exited      bool            // true once the subprocess has exited
	doneLogging chan struct{}   // closed when stderr logging has drained
	l           sync.Mutex      // guards the mutable fields below
	address     net.Addr        // negotiated RPC address; nil until Start succeeds
	process     *os.Process     // the plugin subprocess (or reattached process)
	client      ClientProtocol  // cached protocol client built by Client()
	protocol    Protocol        // negotiated wire protocol (net/rpc or gRPC)
	logger      hclog.Logger    // logger used for subprocess output and events
	doneCtx     context.Context // canceled when the plugin process exits
}
+
// ClientConfig is the configuration used to initialize a new
// plugin client. After being used to initialize a plugin client,
// that configuration must not be modified again.
type ClientConfig struct {
	// HandshakeConfig is the configuration that must match servers.
	HandshakeConfig

	// Plugins are the plugins that can be consumed.
	Plugins map[string]Plugin

	// One of the following must be set, but not both.
	//
	// Cmd is the unstarted subprocess for starting the plugin. If this is
	// set, then the Client starts the plugin process on its own and connects
	// to it.
	//
	// Reattach is configuration for reattaching to an existing plugin process
	// that is already running. This isn't common.
	Cmd      *exec.Cmd
	Reattach *ReattachConfig

	// SecureConfig is configuration for verifying the integrity of the
	// executable. It can not be used with Reattach.
	SecureConfig *SecureConfig

	// TLSConfig is used to enable TLS on the RPC client.
	TLSConfig *tls.Config

	// Managed represents if the client should be managed by the
	// plugin package or not. If true, then by calling CleanupClients,
	// it will automatically be cleaned up. Otherwise, the client
	// user is fully responsible for making sure to Kill all plugin
	// clients. By default the client is _not_ managed.
	Managed bool

	// The minimum and maximum port to use for communicating with
	// the subprocess. If not set, this defaults to 10,000 and 25,000
	// respectively.
	MinPort, MaxPort uint

	// StartTimeout is the timeout to wait for the plugin to say it
	// has started successfully. Defaults to one minute if unset.
	StartTimeout time.Duration

	// If non-nil, then the stderr of the client will be written to here
	// (as well as the log). This is the original os.Stderr of the subprocess.
	// This isn't the output of synced stderr.
	Stderr io.Writer

	// SyncStdout, SyncStderr can be set to override the
	// respective os.Std* values in the plugin. Care should be taken to
	// avoid races here. If these are nil, then this will automatically be
	// hooked up to os.Stdin, Stdout, and Stderr, respectively.
	//
	// If the default values (nil) are used, then this package will not
	// sync any of these streams.
	// NOTE(review): NewClient sets nil values to ioutil.Discard, which
	// matches the "will not sync" sentence above, not the "automatically
	// hooked up" one — confirm the intended default behavior.
	SyncStdout io.Writer
	SyncStderr io.Writer

	// AllowedProtocols is a list of allowed protocols. If this isn't set,
	// then only netrpc is allowed. This is so that older go-plugin systems
	// can show friendly errors if they see a plugin with an unknown
	// protocol.
	//
	// By setting this, you can cause an error immediately on plugin start
	// if an unsupported protocol is used with a good error message.
	//
	// If this isn't set at all (nil value), then only net/rpc is accepted.
	// This is done for legacy reasons. You must explicitly opt-in to
	// new protocols.
	AllowedProtocols []Protocol

	// Logger is the logger that the client will use. If none is provided,
	// it will default to hclog's default logger.
	Logger hclog.Logger
}
+
// ReattachConfig is used to configure a client to reattach to an
// already-running plugin process. You can retrieve this information by
// calling ReattachConfig on Client.
type ReattachConfig struct {
	Protocol Protocol // wire protocol the running plugin speaks
	Addr     net.Addr // address the running plugin is serving on
	Pid      int      // OS process ID of the running plugin
}
+
// SecureConfig is used to configure a client to verify the integrity of an
// executable before running. It does this by verifying the checksum is
// expected. Hash is used to specify the hashing method to use when checksumming
// the file. The configuration is verified by the client by calling the
// SecureConfig.Check() function.
//
// The host process should ensure the checksum was provided by a trusted and
// authoritative source. The binary should be installed in such a way that it
// can not be modified by an unauthorized user between the time of this check
// and the time of execution.
type SecureConfig struct {
	Checksum []byte    // expected digest of the plugin binary
	Hash     hash.Hash // hash implementation used to compute the digest
}
+
+// Check takes the filepath to an executable and returns true if the checksum of
+// the file matches the checksum provided in the SecureConfig.
+func (s *SecureConfig) Check(filePath string) (bool, error) {
+ if len(s.Checksum) == 0 {
+ return false, ErrSecureConfigNoChecksum
+ }
+
+ if s.Hash == nil {
+ return false, ErrSecureConfigNoHash
+ }
+
+ file, err := os.Open(filePath)
+ if err != nil {
+ return false, err
+ }
+ defer file.Close()
+
+ _, err = io.Copy(s.Hash, file)
+ if err != nil {
+ return false, err
+ }
+
+ sum := s.Hash.Sum(nil)
+
+ return subtle.ConstantTimeCompare(sum, s.Checksum) == 1, nil
+}
+
+// This makes sure all the managed subprocesses are killed and properly
+// logged. This should be called before the parent process running the
+// plugins exits.
+//
+// This must only be called _once_.
+func CleanupClients() {
+ // Set the killed to true so that we don't get unexpected panics
+ atomic.StoreUint32(&Killed, 1)
+
+ // Kill all the managed clients in parallel and use a WaitGroup
+ // to wait for them all to finish up.
+ var wg sync.WaitGroup
+ managedClientsLock.Lock()
+ for _, client := range managedClients {
+ wg.Add(1)
+
+ go func(client *Client) {
+ client.Kill()
+ wg.Done()
+ }(client)
+ }
+ managedClientsLock.Unlock()
+
+ wg.Wait()
+}
+
+// Creates a new plugin client which manages the lifecycle of an external
+// plugin and gets the address for the RPC connection.
+//
+// The client must be cleaned up at some point by calling Kill(). If
+// the client is a managed client (created with NewManagedClient) you
+// can just call CleanupClients at the end of your program and they will
+// be properly cleaned.
+func NewClient(config *ClientConfig) (c *Client) {
+ if config.MinPort == 0 && config.MaxPort == 0 {
+ config.MinPort = 10000
+ config.MaxPort = 25000
+ }
+
+ if config.StartTimeout == 0 {
+ config.StartTimeout = 1 * time.Minute
+ }
+
+ if config.Stderr == nil {
+ config.Stderr = ioutil.Discard
+ }
+
+ if config.SyncStdout == nil {
+ config.SyncStdout = ioutil.Discard
+ }
+ if config.SyncStderr == nil {
+ config.SyncStderr = ioutil.Discard
+ }
+
+ if config.AllowedProtocols == nil {
+ config.AllowedProtocols = []Protocol{ProtocolNetRPC}
+ }
+
+ if config.Logger == nil {
+ config.Logger = hclog.New(&hclog.LoggerOptions{
+ Output: hclog.DefaultOutput,
+ Level: hclog.Trace,
+ Name: "plugin",
+ })
+ }
+
+ c = &Client{
+ config: config,
+ logger: config.Logger,
+ }
+ if config.Managed {
+ managedClientsLock.Lock()
+ managedClients = append(managedClients, c)
+ managedClientsLock.Unlock()
+ }
+
+ return
+}
+
+// Client returns the protocol client for this connection.
+//
+// Subsequent calls to this will return the same client.
+func (c *Client) Client() (ClientProtocol, error) {
+ _, err := c.Start()
+ if err != nil {
+ return nil, err
+ }
+
+ c.l.Lock()
+ defer c.l.Unlock()
+
+ if c.client != nil {
+ return c.client, nil
+ }
+
+ switch c.protocol {
+ case ProtocolNetRPC:
+ c.client, err = newRPCClient(c)
+
+ case ProtocolGRPC:
+ c.client, err = newGRPCClient(c.doneCtx, c)
+
+ default:
+ return nil, fmt.Errorf("unknown server protocol: %s", c.protocol)
+ }
+
+ if err != nil {
+ c.client = nil
+ return nil, err
+ }
+
+ return c.client, nil
+}
+
+// Tells whether or not the underlying process has exited.
+func (c *Client) Exited() bool {
+ c.l.Lock()
+ defer c.l.Unlock()
+ return c.exited
+}
+
// Kill ends the executing subprocess (if it is running) and performs any
// cleanup tasks necessary such as capturing any remaining logs and so on.
//
// This method blocks until the process successfully exits.
//
// This method can safely be called multiple times.
func (c *Client) Kill() {
	// Grab a lock to read some private fields. Copies are taken so the
	// lock is not held across the (potentially slow) shutdown below.
	c.l.Lock()
	process := c.process
	addr := c.address
	doneCh := c.doneLogging
	c.l.Unlock()

	// If there is no process, we never started anything. Nothing to kill.
	if process == nil {
		return
	}

	// We need to check for address here. It is possible that the plugin
	// started (process != nil) but has no address (addr == nil) if the
	// plugin failed at startup. If we do have an address, we need to close
	// the plugin net connections.
	graceful := false
	if addr != nil {
		// Close the client to cleanly exit the process.
		client, err := c.Client()
		if err == nil {
			err = client.Close()

			// If there is no error, then we attempt to wait for a graceful
			// exit. If there was an error, we assume that graceful cleanup
			// won't happen and just force kill.
			graceful = err == nil
			if err != nil {
				// If there was an error just log it. We're going to force
				// kill in a moment anyways.
				c.logger.Warn("error closing client during Kill", "err", err)
			}
		}
	}

	// If we're attempting a graceful exit, then we wait for a short period
	// of time to allow that to happen. To wait for this we just wait on the
	// doneCh which would be closed if the process exits.
	if graceful {
		select {
		case <-doneCh:
			// Logging finished, which means the process died; we're done.
			return
		case <-time.After(250 * time.Millisecond):
			// Grace period expired; fall through to the hard kill.
		}
	}

	// If graceful exiting failed, just kill it
	process.Kill()

	// Wait for the client to finish logging so we have a complete log.
	// doneCh is closed by logStderr (or the reattach watcher goroutine)
	// once the process's output has been fully drained.
	<-doneCh
}
+
// Start starts the underlying subprocess, communicating with it to negotiate
// a port for RPC connections, and returns the address to connect to via RPC.
//
// This method is safe to call multiple times. Subsequent calls have no effect.
// Once a client has been started once, it cannot be started again, even if
// it was killed.
func (c *Client) Start() (addr net.Addr, err error) {
	c.l.Lock()
	defer c.l.Unlock()

	// Already started: return the previously negotiated address.
	if c.address != nil {
		return c.address, nil
	}

	// If one of cmd or reattach isn't set, then it is an error. We wrap
	// this in a {} for scoping reasons, hoping escape analysis keeps
	// these temporaries on the stack.
	{
		cmdSet := c.config.Cmd != nil
		attachSet := c.config.Reattach != nil
		secureSet := c.config.SecureConfig != nil
		if cmdSet == attachSet {
			return nil, fmt.Errorf("Only one of Cmd or Reattach must be set")
		}

		if secureSet && attachSet {
			return nil, ErrSecureConfigAndReattach
		}
	}

	// Create the logging channel for when we kill
	c.doneLogging = make(chan struct{})
	// Create a context for when we kill
	var ctxCancel context.CancelFunc
	c.doneCtx, ctxCancel = context.WithCancel(context.Background())

	if c.config.Reattach != nil {
		// Verify the process still exists. If not, then it is an error
		p, err := os.FindProcess(c.config.Reattach.Pid)
		if err != nil {
			return nil, err
		}

		// Attempt to connect to the addr since on Unix systems FindProcess
		// doesn't actually return an error if it can't find the process.
		conn, err := net.Dial(
			c.config.Reattach.Addr.Network(),
			c.config.Reattach.Addr.String())
		if err != nil {
			p.Kill()
			return nil, ErrProcessNotFound
		}
		conn.Close()

		// Goroutine to mark exit status
		go func(pid int) {
			// Wait for the process to die
			pidWait(pid)

			// Log so we can see it
			c.logger.Debug("reattached plugin process exited")

			// Mark it
			c.l.Lock()
			defer c.l.Unlock()
			c.exited = true

			// Close the logging channel since that doesn't work on reattach
			close(c.doneLogging)

			// Cancel the context
			ctxCancel()
		}(p.Pid)

		// Set the address and process
		c.address = c.config.Reattach.Addr
		c.process = p
		c.protocol = c.config.Reattach.Protocol
		if c.protocol == "" {
			// Default the protocol to net/rpc for backwards compatibility
			c.protocol = ProtocolNetRPC
		}

		return c.address, nil
	}

	// Environment handed to the plugin: the magic cookie proves it is
	// launched by a compatible host, and the port range bounds its listener.
	env := []string{
		fmt.Sprintf("%s=%s", c.config.MagicCookieKey, c.config.MagicCookieValue),
		fmt.Sprintf("PLUGIN_MIN_PORT=%d", c.config.MinPort),
		fmt.Sprintf("PLUGIN_MAX_PORT=%d", c.config.MaxPort),
	}

	// In-process pipes so we can parse the handshake line from stdout
	// and stream stderr into our logger.
	stdout_r, stdout_w := io.Pipe()
	stderr_r, stderr_w := io.Pipe()

	cmd := c.config.Cmd
	cmd.Env = append(cmd.Env, os.Environ()...)
	cmd.Env = append(cmd.Env, env...)
	cmd.Stdin = os.Stdin
	cmd.Stderr = stderr_w
	cmd.Stdout = stdout_w

	// Verify the binary's checksum before executing it, if configured.
	if c.config.SecureConfig != nil {
		if ok, err := c.config.SecureConfig.Check(cmd.Path); err != nil {
			return nil, fmt.Errorf("error verifying checksum: %s", err)
		} else if !ok {
			return nil, ErrChecksumsDoNotMatch
		}
	}

	c.logger.Debug("starting plugin", "path", cmd.Path, "args", cmd.Args)
	err = cmd.Start()
	if err != nil {
		return
	}

	// Set the process
	c.process = cmd.Process

	// Make sure the command is properly cleaned up if there is an error
	// or a panic anywhere below (err is the named return value).
	defer func() {
		r := recover()

		if err != nil || r != nil {
			cmd.Process.Kill()
		}

		if r != nil {
			panic(r)
		}
	}()

	// Start goroutine to wait for process to exit
	exitCh := make(chan struct{})
	go func() {
		// Make sure we close the write end of our stderr/stdout so
		// that the readers send EOF properly.
		defer stderr_w.Close()
		defer stdout_w.Close()

		// Wait for the command to end.
		cmd.Wait()

		// Log and make sure to flush the logs right away
		c.logger.Debug("plugin process exited", "path", cmd.Path)
		os.Stderr.Sync()

		// Mark that we exited
		close(exitCh)

		// Cancel the context, marking that we exited
		ctxCancel()

		// Set that we exited, which takes a lock
		c.l.Lock()
		defer c.l.Unlock()
		c.exited = true
	}()

	// Start goroutine that logs the stderr
	go c.logStderr(stderr_r)

	// Start a goroutine that is going to be reading the lines
	// out of stdout
	linesCh := make(chan []byte)
	go func() {
		defer close(linesCh)

		buf := bufio.NewReader(stdout_r)
		for {
			line, err := buf.ReadBytes('\n')
			if line != nil {
				linesCh <- line
			}

			if err == io.EOF {
				return
			}
		}
	}()

	// Make sure after we exit we read the lines from stdout forever
	// so they don't block since it is an io.Pipe
	defer func() {
		go func() {
			for _ = range linesCh {
			}
		}()
	}()

	// Some channels for the next step
	timeout := time.After(c.config.StartTimeout)

	// Start looking for the address. The plugin's first stdout line is
	// the handshake: "CORE-PROTOCOL|APP-PROTOCOL|NETWORK|ADDR[|PROTOCOL]".
	c.logger.Debug("waiting for RPC address", "path", cmd.Path)
	select {
	case <-timeout:
		err = errors.New("timeout while waiting for plugin to start")
	case <-exitCh:
		err = errors.New("plugin exited before we could connect")
	case lineBytes := <-linesCh:
		// Trim the line and split by "|" in order to get the parts of
		// the output.
		line := strings.TrimSpace(string(lineBytes))
		parts := strings.SplitN(line, "|", 6)
		if len(parts) < 4 {
			err = fmt.Errorf(
				"Unrecognized remote plugin message: %s\n\n"+
					"This usually means that the plugin is either invalid or simply\n"+
					"needs to be recompiled to support the latest protocol.", line)
			return
		}

		// Check the core protocol. Wrapped in a {} for scoping.
		{
			var coreProtocol int64
			coreProtocol, err = strconv.ParseInt(parts[0], 10, 0)
			if err != nil {
				err = fmt.Errorf("Error parsing core protocol version: %s", err)
				return
			}

			if int(coreProtocol) != CoreProtocolVersion {
				err = fmt.Errorf("Incompatible core API version with plugin. "+
					"Plugin version: %s, Core version: %d\n\n"+
					"To fix this, the plugin usually only needs to be recompiled.\n"+
					"Please report this to the plugin author.", parts[0], CoreProtocolVersion)
				return
			}
		}

		// Parse the protocol version
		var protocol int64
		protocol, err = strconv.ParseInt(parts[1], 10, 0)
		if err != nil {
			err = fmt.Errorf("Error parsing protocol version: %s", err)
			return
		}

		// Test the API version
		if uint(protocol) != c.config.ProtocolVersion {
			err = fmt.Errorf("Incompatible API version with plugin. "+
				"Plugin version: %s, Core version: %d", parts[1], c.config.ProtocolVersion)
			return
		}

		// Resolve the listener address the plugin reported.
		switch parts[2] {
		case "tcp":
			addr, err = net.ResolveTCPAddr("tcp", parts[3])
		case "unix":
			addr, err = net.ResolveUnixAddr("unix", parts[3])
		default:
			err = fmt.Errorf("Unknown address type: %s", parts[3])
		}

		// If we have a server type, then record that. We default to net/rpc
		// for backwards compatibility.
		c.protocol = ProtocolNetRPC
		if len(parts) >= 5 {
			c.protocol = Protocol(parts[4])
		}

		// Reject protocols the host did not opt in to.
		found := false
		for _, p := range c.config.AllowedProtocols {
			if p == c.protocol {
				found = true
				break
			}
		}
		if !found {
			err = fmt.Errorf("Unsupported plugin protocol %q. Supported: %v",
				c.protocol, c.config.AllowedProtocols)
			return
		}

	}

	c.address = addr
	return
}
+
+// ReattachConfig returns the information that must be provided to NewClient
+// to reattach to the plugin process that this client started. This is
+// useful for plugins that detach from their parent process.
+//
+// If this returns nil then the process hasn't been started yet. Please
+// call Start or Client before calling this.
+func (c *Client) ReattachConfig() *ReattachConfig {
+ c.l.Lock()
+ defer c.l.Unlock()
+
+ if c.address == nil {
+ return nil
+ }
+
+ if c.config.Cmd != nil && c.config.Cmd.Process == nil {
+ return nil
+ }
+
+ // If we connected via reattach, just return the information as-is
+ if c.config.Reattach != nil {
+ return c.config.Reattach
+ }
+
+ return &ReattachConfig{
+ Protocol: c.protocol,
+ Addr: c.address,
+ Pid: c.config.Cmd.Process.Pid,
+ }
+}
+
+// Protocol returns the protocol of server on the remote end. This will
+// start the plugin process if it isn't already started. Errors from
+// starting the plugin are surpressed and ProtocolInvalid is returned. It
+// is recommended you call Start explicitly before calling Protocol to ensure
+// no errors occur.
+func (c *Client) Protocol() Protocol {
+ _, err := c.Start()
+ if err != nil {
+ return ProtocolInvalid
+ }
+
+ return c.protocol
+}
+
+func netAddrDialer(addr net.Addr) func(string, time.Duration) (net.Conn, error) {
+ return func(_ string, _ time.Duration) (net.Conn, error) {
+ // Connect to the client
+ conn, err := net.Dial(addr.Network(), addr.String())
+ if err != nil {
+ return nil, err
+ }
+ if tcpConn, ok := conn.(*net.TCPConn); ok {
+ // Make sure to set keep alive so that the connection doesn't die
+ tcpConn.SetKeepAlive(true)
+ }
+
+ return conn, nil
+ }
+}
+
+// dialer is compatible with grpc.WithDialer and creates the connection
+// to the plugin.
+func (c *Client) dialer(_ string, timeout time.Duration) (net.Conn, error) {
+ conn, err := netAddrDialer(c.address)("", timeout)
+ if err != nil {
+ return nil, err
+ }
+
+ // If we have a TLS config we wrap our connection. We only do this
+ // for net/rpc since gRPC uses its own mechanism for TLS.
+ if c.protocol == ProtocolNetRPC && c.config.TLSConfig != nil {
+ conn = tls.Client(conn, c.config.TLSConfig)
+ }
+
+ return conn, nil
+}
+
+func (c *Client) logStderr(r io.Reader) {
+ bufR := bufio.NewReader(r)
+ l := c.logger.Named(filepath.Base(c.config.Cmd.Path))
+
+ for {
+ line, err := bufR.ReadString('\n')
+ if line != "" {
+ c.config.Stderr.Write([]byte(line))
+ line = strings.TrimRightFunc(line, unicode.IsSpace)
+
+ entry, err := parseJSON(line)
+ // If output is not JSON format, print directly to Debug
+ if err != nil {
+ l.Debug(line)
+ } else {
+ out := flattenKVPairs(entry.KVPairs)
+
+ out = append(out, "timestamp", entry.Timestamp.Format(hclog.TimeFormat))
+ switch hclog.LevelFromString(entry.Level) {
+ case hclog.Trace:
+ l.Trace(entry.Message, out...)
+ case hclog.Debug:
+ l.Debug(entry.Message, out...)
+ case hclog.Info:
+ l.Info(entry.Message, out...)
+ case hclog.Warn:
+ l.Warn(entry.Message, out...)
+ case hclog.Error:
+ l.Error(entry.Message, out...)
+ }
+ }
+ }
+
+ if err == io.EOF {
+ break
+ }
+ }
+
+ // Flag that we've completed logging for others
+ close(c.doneLogging)
+}
--- /dev/null
+// +build !windows
+
+package plugin
+
+import (
+ "os"
+ "reflect"
+ "syscall"
+ "testing"
+ "time"
+)
+
// TestClient_testInterfaceReattach exercises the full reattach flow:
// launch a daemonized plugin, capture its reattach config, connect a
// fresh client to the already-running process, and verify the plugin
// still works through the new client.
func TestClient_testInterfaceReattach(t *testing.T) {
	// Setup the process for daemonization: a new session (Setsid) so
	// the plugin survives independently of this test process.
	process := helperProcess("test-interface-daemon")
	if process.SysProcAttr == nil {
		process.SysProcAttr = &syscall.SysProcAttr{}
	}
	process.SysProcAttr.Setsid = true
	syscall.Umask(0)

	c := NewClient(&ClientConfig{
		Cmd:             process,
		HandshakeConfig: testHandshake,
		Plugins:         testPluginMap,
	})

	// Start it so we can get the reattach info
	if _, err := c.Start(); err != nil {
		t.Fatalf("err should be nil, got %s", err)
	}

	// New client with reattach info
	reattach := c.ReattachConfig()
	if reattach == nil {
		c.Kill()
		t.Fatal("reattach config should be non-nil")
	}

	// Find the process and defer a kill so we know it is gone
	p, err := os.FindProcess(reattach.Pid)
	if err != nil {
		c.Kill()
		t.Fatalf("couldn't find process: %s", err)
	}
	defer p.Kill()

	// Reattach a brand-new client to the running daemon.
	c = NewClient(&ClientConfig{
		Reattach:        reattach,
		HandshakeConfig: testHandshake,
		Plugins:         testPluginMap,
	})

	// Start shouldn't error
	if _, err := c.Start(); err != nil {
		t.Fatalf("err: %s", err)
	}

	// It should still be alive
	time.Sleep(1 * time.Second)
	if c.Exited() {
		t.Fatal("should not be exited")
	}

	// Grab the RPC client
	client, err := c.Client()
	if err != nil {
		t.Fatalf("err should be nil, got %s", err)
	}

	// Grab the impl
	raw, err := client.Dispense("test")
	if err != nil {
		t.Fatalf("err should be nil, got %s", err)
	}

	impl, ok := raw.(testInterface)
	if !ok {
		t.Fatalf("bad: %#v", raw)
	}

	result := impl.Double(21)
	if result != 42 {
		t.Fatalf("bad: %#v", result)
	}

	// Test the resulting reattach config: reattaching should round-trip
	// the exact same configuration.
	reattach2 := c.ReattachConfig()
	if reattach2 == nil {
		t.Fatal("reattach from reattached should not be nil")
	}
	if !reflect.DeepEqual(reattach, reattach2) {
		t.Fatalf("bad: %#v", reattach)
	}

	// Kill it
	c.Kill()

	// Test that it knows it is exited
	if !c.Exited() {
		t.Fatal("should say client has exited")
	}
}
--- /dev/null
+package plugin
+
+import (
+ "bytes"
+ "crypto/sha256"
+ "io"
+ "io/ioutil"
+ "net"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ hclog "github.com/hashicorp/go-hclog"
+)
+
+func TestClient(t *testing.T) {
+ process := helperProcess("mock")
+ c := NewClient(&ClientConfig{Cmd: process, HandshakeConfig: testHandshake})
+ defer c.Kill()
+
+ // Test that it parses the proper address
+ addr, err := c.Start()
+ if err != nil {
+ t.Fatalf("err should be nil, got %s", err)
+ }
+
+ if addr.Network() != "tcp" {
+ t.Fatalf("bad: %#v", addr)
+ }
+
+ if addr.String() != ":1234" {
+ t.Fatalf("bad: %#v", addr)
+ }
+
+ // Test that it exits properly if killed
+ c.Kill()
+
+ if process.ProcessState == nil {
+ t.Fatal("should have process state")
+ }
+
+ // Test that it knows it is exited
+ if !c.Exited() {
+ t.Fatal("should say client has exited")
+ }
+}
+
+// This tests a bug where Kill would start
+func TestClient_killStart(t *testing.T) {
+ // Create a temporary dir to store the result file
+ td, err := ioutil.TempDir("", "plugin")
+ if err != nil {
+ t.Fatalf("err: %s", err)
+ }
+ defer os.RemoveAll(td)
+
+ // Start the client
+ path := filepath.Join(td, "booted")
+ process := helperProcess("bad-version", path)
+ c := NewClient(&ClientConfig{Cmd: process, HandshakeConfig: testHandshake})
+ defer c.Kill()
+
+ // Verify our path doesn't exist
+ if _, err := os.Stat(path); err == nil || !os.IsNotExist(err) {
+ t.Fatalf("bad: %s", err)
+ }
+
+ // Test that it parses the proper address
+ if _, err := c.Start(); err == nil {
+ t.Fatal("expected error")
+ }
+
+ // Verify we started
+ if _, err := os.Stat(path); err != nil {
+ t.Fatalf("bad: %s", err)
+ }
+ if err := os.Remove(path); err != nil {
+ t.Fatalf("bad: %s", err)
+ }
+
+ // Test that Kill does nothing really
+ c.Kill()
+
+ // Test that it knows it is exited
+ if !c.Exited() {
+ t.Fatal("should say client has exited")
+ }
+
+ if process.ProcessState == nil {
+ t.Fatal("should have no process state")
+ }
+
+ // Verify our path doesn't exist
+ if _, err := os.Stat(path); err == nil || !os.IsNotExist(err) {
+ t.Fatalf("bad: %s", err)
+ }
+}
+
+func TestClient_testCleanup(t *testing.T) {
+ // Create a temporary dir to store the result file
+ td, err := ioutil.TempDir("", "plugin")
+ if err != nil {
+ t.Fatalf("err: %s", err)
+ }
+ defer os.RemoveAll(td)
+
+ // Create a path that the helper process will write on cleanup
+ path := filepath.Join(td, "output")
+
+ // Test the cleanup
+ process := helperProcess("cleanup", path)
+ c := NewClient(&ClientConfig{
+ Cmd: process,
+ HandshakeConfig: testHandshake,
+ Plugins: testPluginMap,
+ })
+
+ // Grab the client so the process starts
+ if _, err := c.Client(); err != nil {
+ c.Kill()
+ t.Fatalf("err: %s", err)
+ }
+
+ // Kill it gracefully
+ c.Kill()
+
+ // Test for the file
+ if _, err := os.Stat(path); err != nil {
+ t.Fatalf("err: %s", err)
+ }
+}
+
+func TestClient_testInterface(t *testing.T) {
+ process := helperProcess("test-interface")
+ c := NewClient(&ClientConfig{
+ Cmd: process,
+ HandshakeConfig: testHandshake,
+ Plugins: testPluginMap,
+ })
+ defer c.Kill()
+
+ // Grab the RPC client
+ client, err := c.Client()
+ if err != nil {
+ t.Fatalf("err should be nil, got %s", err)
+ }
+
+ // Grab the impl
+ raw, err := client.Dispense("test")
+ if err != nil {
+ t.Fatalf("err should be nil, got %s", err)
+ }
+
+ impl, ok := raw.(testInterface)
+ if !ok {
+ t.Fatalf("bad: %#v", raw)
+ }
+
+ result := impl.Double(21)
+ if result != 42 {
+ t.Fatalf("bad: %#v", result)
+ }
+
+ // Kill it
+ c.Kill()
+
+ // Test that it knows it is exited
+ if !c.Exited() {
+ t.Fatal("should say client has exited")
+ }
+}
+
+func TestClient_grpc_servercrash(t *testing.T) {
+ process := helperProcess("test-grpc")
+ c := NewClient(&ClientConfig{
+ Cmd: process,
+ HandshakeConfig: testHandshake,
+ Plugins: testPluginMap,
+ AllowedProtocols: []Protocol{ProtocolGRPC},
+ })
+ defer c.Kill()
+
+ if _, err := c.Start(); err != nil {
+ t.Fatalf("err: %s", err)
+ }
+
+ if v := c.Protocol(); v != ProtocolGRPC {
+ t.Fatalf("bad: %s", v)
+ }
+
+ // Grab the RPC client
+ client, err := c.Client()
+ if err != nil {
+ t.Fatalf("err should be nil, got %s", err)
+ }
+
+ // Grab the impl
+ raw, err := client.Dispense("test")
+ if err != nil {
+ t.Fatalf("err should be nil, got %s", err)
+ }
+
+ _, ok := raw.(testInterface)
+ if !ok {
+ t.Fatalf("bad: %#v", raw)
+ }
+
+ c.process.Kill()
+
+ select {
+ case <-c.doneCtx.Done():
+ case <-time.After(time.Second * 2):
+ t.Fatal("Context was not closed")
+ }
+}
+
+func TestClient_grpc(t *testing.T) {
+ process := helperProcess("test-grpc")
+ c := NewClient(&ClientConfig{
+ Cmd: process,
+ HandshakeConfig: testHandshake,
+ Plugins: testPluginMap,
+ AllowedProtocols: []Protocol{ProtocolGRPC},
+ })
+ defer c.Kill()
+
+ if _, err := c.Start(); err != nil {
+ t.Fatalf("err: %s", err)
+ }
+
+ if v := c.Protocol(); v != ProtocolGRPC {
+ t.Fatalf("bad: %s", v)
+ }
+
+ // Grab the RPC client
+ client, err := c.Client()
+ if err != nil {
+ t.Fatalf("err should be nil, got %s", err)
+ }
+
+ // Grab the impl
+ raw, err := client.Dispense("test")
+ if err != nil {
+ t.Fatalf("err should be nil, got %s", err)
+ }
+
+ impl, ok := raw.(testInterface)
+ if !ok {
+ t.Fatalf("bad: %#v", raw)
+ }
+
+ result := impl.Double(21)
+ if result != 42 {
+ t.Fatalf("bad: %#v", result)
+ }
+
+ // Kill it
+ c.Kill()
+
+ // Test that it knows it is exited
+ if !c.Exited() {
+ t.Fatal("should say client has exited")
+ }
+}
+
+func TestClient_grpcNotAllowed(t *testing.T) {
+ process := helperProcess("test-grpc")
+ c := NewClient(&ClientConfig{
+ Cmd: process,
+ HandshakeConfig: testHandshake,
+ Plugins: testPluginMap,
+ })
+ defer c.Kill()
+
+ if _, err := c.Start(); err == nil {
+ t.Fatal("should error")
+ }
+}
+
+// TestClient_cmdAndReattach ensures that supplying both a command and a
+// reattach configuration is rejected; the two are mutually exclusive.
+func TestClient_cmdAndReattach(t *testing.T) {
+	c := NewClient(&ClientConfig{
+		Cmd:      helperProcess("start-timeout"),
+		Reattach: &ReattachConfig{},
+	})
+	defer c.Kill()
+
+	if _, err := c.Start(); err == nil {
+		t.Fatal("err should not be nil")
+	}
+}
+
+// TestClient_reattach launches a plugin, then verifies that a second
+// client can reattach to the still-running process and use it.
+func TestClient_reattach(t *testing.T) {
+	c := NewClient(&ClientConfig{
+		Cmd:             helperProcess("test-interface"),
+		HandshakeConfig: testHandshake,
+		Plugins:         testPluginMap,
+	})
+	defer c.Kill()
+
+	// Start the plugin via the first client.
+	if _, err := c.Client(); err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	// Capture the reattach configuration, then build a brand-new client
+	// that attaches to the already-running plugin process.
+	reattach := c.ReattachConfig()
+	c = NewClient(&ClientConfig{
+		Reattach:        reattach,
+		HandshakeConfig: testHandshake,
+		Plugins:         testPluginMap,
+	})
+
+	rpcClient, err := c.Client()
+	if err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	raw, err := rpcClient.Dispense("test")
+	if err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	impl, ok := raw.(testInterface)
+	if !ok {
+		t.Fatalf("bad: %#v", raw)
+	}
+
+	if got := impl.Double(21); got != 42 {
+		t.Fatalf("bad: %#v", got)
+	}
+
+	// Kill and confirm the exit is observed.
+	c.Kill()
+	if !c.Exited() {
+		t.Fatal("should say client has exited")
+	}
+}
+
+// TestClient_reattachNoProtocol verifies that reattaching still works
+// when the reattach configuration carries no protocol (as produced by
+// older versions that predate protocol negotiation).
+func TestClient_reattachNoProtocol(t *testing.T) {
+	c := NewClient(&ClientConfig{
+		Cmd:             helperProcess("test-interface"),
+		HandshakeConfig: testHandshake,
+		Plugins:         testPluginMap,
+	})
+	defer c.Kill()
+
+	// Start the plugin via the first client.
+	if _, err := c.Client(); err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	// Blank the protocol to simulate a legacy reattach configuration.
+	reattach := c.ReattachConfig()
+	reattach.Protocol = ""
+
+	// Attach a fresh client to the running plugin.
+	c = NewClient(&ClientConfig{
+		Reattach:        reattach,
+		HandshakeConfig: testHandshake,
+		Plugins:         testPluginMap,
+	})
+
+	rpcClient, err := c.Client()
+	if err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	raw, err := rpcClient.Dispense("test")
+	if err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	impl, ok := raw.(testInterface)
+	if !ok {
+		t.Fatalf("bad: %#v", raw)
+	}
+
+	if got := impl.Double(21); got != 42 {
+		t.Fatalf("bad: %#v", got)
+	}
+
+	// Kill and confirm the exit is observed.
+	c.Kill()
+	if !c.Exited() {
+		t.Fatal("should say client has exited")
+	}
+}
+
+// TestClient_reattachGRPC is the reattach test for a plugin served over
+// gRPC rather than net/rpc.
+func TestClient_reattachGRPC(t *testing.T) {
+	c := NewClient(&ClientConfig{
+		Cmd:              helperProcess("test-grpc"),
+		HandshakeConfig:  testHandshake,
+		Plugins:          testPluginMap,
+		AllowedProtocols: []Protocol{ProtocolGRPC},
+	})
+	defer c.Kill()
+
+	// Start the plugin via the first client.
+	if _, err := c.Client(); err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	// Attach a brand-new client using the recorded configuration.
+	reattach := c.ReattachConfig()
+	c = NewClient(&ClientConfig{
+		Reattach:         reattach,
+		HandshakeConfig:  testHandshake,
+		Plugins:          testPluginMap,
+		AllowedProtocols: []Protocol{ProtocolGRPC},
+	})
+
+	rpcClient, err := c.Client()
+	if err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	raw, err := rpcClient.Dispense("test")
+	if err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	impl, ok := raw.(testInterface)
+	if !ok {
+		t.Fatalf("bad: %#v", raw)
+	}
+
+	if got := impl.Double(21); got != 42 {
+		t.Fatalf("bad: %#v", got)
+	}
+
+	// Kill and confirm the exit is observed.
+	c.Kill()
+	if !c.Exited() {
+		t.Fatal("should say client has exited")
+	}
+}
+
+// TestClient_reattachNotFound verifies that reattaching to a process
+// that is not actually a running plugin fails with ErrProcessNotFound.
+func TestClient_reattachNotFound(t *testing.T) {
+	// Find a bad pid
+	// NOTE(review): on Unix, os.FindProcess never returns an error for
+	// any pid, so this loop likely leaves pid at 5000 regardless. The
+	// dead listener address below is what actually guarantees the
+	// reattach fails — confirm before relying on this pid probe.
+	var pid int = 5000
+	for i := pid; i < 32000; i++ {
+		if _, err := os.FindProcess(i); err != nil {
+			pid = i
+			break
+		}
+	}
+
+	// Addr that won't work: bind an ephemeral port, then immediately
+	// close the listener so nothing is serving there.
+	l, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	addr := l.Addr()
+	l.Close()
+
+	// Reattach using the dead address and the (probably bogus) pid.
+	c := NewClient(&ClientConfig{
+		Reattach: &ReattachConfig{
+			Addr: addr,
+			Pid:  pid,
+		},
+		HandshakeConfig: testHandshake,
+		Plugins:         testPluginMap,
+	})
+
+	// Start shouldn't error
+	if _, err := c.Start(); err == nil {
+		t.Fatal("should error")
+	} else if err != ErrProcessNotFound {
+		t.Fatalf("err: %s", err)
+	}
+}
+
+// TestClientStart_badVersion ensures the handshake fails when the plugin
+// reports an unsupported protocol version.
+func TestClientStart_badVersion(t *testing.T) {
+	c := NewClient(&ClientConfig{
+		Cmd:             helperProcess("bad-version"),
+		StartTimeout:    50 * time.Millisecond,
+		HandshakeConfig: testHandshake,
+	})
+	defer c.Kill()
+
+	if _, err := c.Start(); err == nil {
+		t.Fatal("err should not be nil")
+	}
+}
+
+// TestClient_Start_Timeout ensures Start gives up with an error once the
+// configured StartTimeout elapses without a handshake.
+func TestClient_Start_Timeout(t *testing.T) {
+	c := NewClient(&ClientConfig{
+		Cmd:             helperProcess("start-timeout"),
+		StartTimeout:    50 * time.Millisecond,
+		HandshakeConfig: testHandshake,
+	})
+	defer c.Kill()
+
+	if _, err := c.Start(); err == nil {
+		t.Fatal("err should not be nil")
+	}
+}
+
+func TestClient_Stderr(t *testing.T) {
+ stderr := new(bytes.Buffer)
+ process := helperProcess("stderr")
+ c := NewClient(&ClientConfig{
+ Cmd: process,
+ Stderr: stderr,
+ HandshakeConfig: testHandshake,
+ })
+ defer c.Kill()
+
+ if _, err := c.Start(); err != nil {
+ t.Fatalf("err: %s", err)
+ }
+
+ for !c.Exited() {
+ time.Sleep(10 * time.Millisecond)
+ }
+
+ if !strings.Contains(stderr.String(), "HELLO\n") {
+ t.Fatalf("bad log data: '%s'", stderr.String())
+ }
+
+ if !strings.Contains(stderr.String(), "WORLD\n") {
+ t.Fatalf("bad log data: '%s'", stderr.String())
+ }
+}
+
+func TestClient_StderrJSON(t *testing.T) {
+ stderr := new(bytes.Buffer)
+ process := helperProcess("stderr-json")
+ c := NewClient(&ClientConfig{
+ Cmd: process,
+ Stderr: stderr,
+ HandshakeConfig: testHandshake,
+ })
+ defer c.Kill()
+
+ if _, err := c.Start(); err != nil {
+ t.Fatalf("err: %s", err)
+ }
+
+ for !c.Exited() {
+ time.Sleep(10 * time.Millisecond)
+ }
+
+ if !strings.Contains(stderr.String(), "[\"HELLO\"]\n") {
+ t.Fatalf("bad log data: '%s'", stderr.String())
+ }
+
+ if !strings.Contains(stderr.String(), "12345\n") {
+ t.Fatalf("bad log data: '%s'", stderr.String())
+ }
+}
+
+// TestClient_Stdin verifies that the host process's stdin is forwarded
+// to the plugin subprocess.
+func TestClient_Stdin(t *testing.T) {
+	// Replace our own stdin with a temp file holding known content.
+	tmpFile, err := ioutil.TempFile("", "terraform")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+	defer os.Remove(tmpFile.Name())
+	defer tmpFile.Close()
+
+	if _, err = tmpFile.WriteString("hello"); err != nil {
+		t.Fatalf("error: %s", err)
+	}
+	if err = tmpFile.Sync(); err != nil {
+		t.Fatalf("error: %s", err)
+	}
+	if _, err = tmpFile.Seek(0, 0); err != nil {
+		t.Fatalf("error: %s", err)
+	}
+
+	// Swap os.Stdin for the duration of the test.
+	oldStdin := os.Stdin
+	defer func() { os.Stdin = oldStdin }()
+	os.Stdin = tmpFile
+
+	process := helperProcess("stdin")
+	c := NewClient(&ClientConfig{Cmd: process, HandshakeConfig: testHandshake})
+	defer c.Kill()
+
+	if _, err = c.Start(); err != nil {
+		t.Fatalf("error: %s", err)
+	}
+
+	// Wait for the helper to finish; it exits successfully only if it
+	// read the expected content from stdin.
+	for !c.Exited() {
+		time.Sleep(50 * time.Millisecond)
+	}
+
+	if !process.ProcessState.Success() {
+		t.Fatal("process didn't exit cleanly")
+	}
+}
+
+// TestClient_SecureConfig verifies plugin-binary checksum verification:
+// a bogus checksum must be rejected with ErrChecksumsDoNotMatch, and the
+// real checksum of the test executable must be accepted.
+func TestClient_SecureConfig(t *testing.T) {
+	// Test failure case: a checksum that cannot possibly match.
+	secureConfig := &SecureConfig{
+		Checksum: []byte{'1'},
+		Hash:     sha256.New(),
+	}
+	c := NewClient(&ClientConfig{
+		Cmd:             helperProcess("test-interface"),
+		HandshakeConfig: testHandshake,
+		Plugins:         testPluginMap,
+		SecureConfig:    secureConfig,
+	})
+
+	// Grab the RPC client, should error
+	_, err := c.Client()
+	c.Kill()
+	if err != ErrChecksumsDoNotMatch {
+		t.Fatalf("err should be %s, got %s", ErrChecksumsDoNotMatch, err)
+	}
+
+	// Get the checksum of the executable (the helper plugin is this test
+	// binary re-executed, so hash os.Args[0]).
+	file, err := os.Open(os.Args[0])
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer file.Close()
+
+	hash := sha256.New()
+	if _, err = io.Copy(hash, file); err != nil {
+		t.Fatal(err)
+	}
+
+	secureConfig = &SecureConfig{
+		Checksum: hash.Sum(nil),
+		Hash:     sha256.New(),
+	}
+
+	// Use a fresh exec.Cmd here: an exec.Cmd cannot be reused once it
+	// has run, so sharing the one from the failure case above would be
+	// fragile if the failing attempt ever starts the process.
+	c = NewClient(&ClientConfig{
+		Cmd:             helperProcess("test-interface"),
+		HandshakeConfig: testHandshake,
+		Plugins:         testPluginMap,
+		SecureConfig:    secureConfig,
+	})
+	defer c.Kill()
+
+	// Grab the RPC client; with the correct checksum this must succeed.
+	if _, err = c.Client(); err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+}
+
+// TestClient_TLS verifies that a TLS-serving plugin rejects a client
+// lacking TLS configuration and works once the config is supplied.
+func TestClient_TLS(t *testing.T) {
+	// Failure case: no TLS config on the client side.
+	cBad := NewClient(&ClientConfig{
+		Cmd:             helperProcess("test-interface-tls"),
+		HandshakeConfig: testHandshake,
+		Plugins:         testPluginMap,
+	})
+	defer cBad.Kill()
+
+	clientBad, err := cBad.Client()
+	if err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	// Dispensing must fail without TLS.
+	if _, err := clientBad.Dispense("test"); err == nil {
+		t.Fatal("expected error, got nil")
+	}
+
+	cBad.Kill()
+
+	// Success case: supply the matching TLS configuration.
+	tlsConfig, err := helperTLSProvider()
+	if err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	c := NewClient(&ClientConfig{
+		Cmd:             helperProcess("test-interface-tls"),
+		HandshakeConfig: testHandshake,
+		Plugins:         testPluginMap,
+		TLSConfig:       tlsConfig,
+	})
+	defer c.Kill()
+
+	client, err := c.Client()
+	if err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	raw, err := client.Dispense("test")
+	if err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	impl, ok := raw.(testInterface)
+	if !ok {
+		t.Fatalf("bad: %#v", raw)
+	}
+
+	if got := impl.Double(21); got != 42 {
+		t.Fatalf("bad: %#v", got)
+	}
+
+	// Kill and confirm the exit is observed.
+	c.Kill()
+	if !c.Exited() {
+		t.Fatal("should say client has exited")
+	}
+}
+
+// TestClient_TLS_grpc exercises a TLS-protected plugin over the gRPC
+// transport.
+func TestClient_TLS_grpc(t *testing.T) {
+	// Add TLS config to client
+	tlsConfig, err := helperTLSProvider()
+	if err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	c := NewClient(&ClientConfig{
+		Cmd:              helperProcess("test-grpc-tls"),
+		HandshakeConfig:  testHandshake,
+		Plugins:          testPluginMap,
+		TLSConfig:        tlsConfig,
+		AllowedProtocols: []Protocol{ProtocolGRPC},
+	})
+	defer c.Kill()
+
+	client, err := c.Client()
+	if err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	raw, err := client.Dispense("test")
+	if err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	impl, ok := raw.(testInterface)
+	if !ok {
+		t.Fatalf("bad: %#v", raw)
+	}
+
+	if got := impl.Double(21); got != 42 {
+		t.Fatalf("bad: %#v", got)
+	}
+
+	// Kill and confirm the exit is observed.
+	c.Kill()
+	if !c.Exited() {
+		t.Fatal("should say client has exited")
+	}
+}
+
+// TestClient_secureConfigAndReattach ensures SecureConfig and Reattach
+// cannot be combined: the binary checksum cannot be verified for a
+// process that is already running.
+func TestClient_secureConfigAndReattach(t *testing.T) {
+	config := &ClientConfig{
+		SecureConfig: &SecureConfig{},
+		Reattach:     &ReattachConfig{},
+	}
+
+	c := NewClient(config)
+	defer c.Kill()
+
+	_, err := c.Start()
+	if err != ErrSecureConfigAndReattach {
+		// Fixed message: the test expects equality with the sentinel,
+		// so say "should be", not "should not be".
+		t.Fatalf("err should be %s, got %s", ErrSecureConfigAndReattach, err)
+	}
+}
+
+// TestClient_ping checks that Ping succeeds while the plugin is alive
+// and fails once the plugin has been killed.
+func TestClient_ping(t *testing.T) {
+	c := NewClient(&ClientConfig{
+		Cmd:             helperProcess("test-interface"),
+		HandshakeConfig: testHandshake,
+		Plugins:         testPluginMap,
+	})
+	defer c.Kill()
+
+	client, err := c.Client()
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	// Alive: ping must succeed.
+	if err := client.Ping(); err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	// Dead: ping must fail.
+	c.Kill()
+	if err := client.Ping(); err == nil {
+		t.Fatal("should error")
+	}
+}
+
+// TestClient_logger runs the logger integration test once per supported
+// transport protocol.
+func TestClient_logger(t *testing.T) {
+	cases := []struct {
+		name  string
+		proto string
+	}{
+		{"net/rpc", "netrpc"},
+		{"grpc", "grpc"},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) { testClient_logger(t, tc.proto) })
+	}
+}
+
+// testClient_logger verifies that log entries emitted by the plugin via
+// its hclog logger are forwarded to the host's configured logger, for
+// the given transport protocol ("netrpc" or "grpc").
+func testClient_logger(t *testing.T, proto string) {
+	var buffer bytes.Buffer
+	mutex := new(sync.Mutex)
+	// Tee plugin log output to our stderr (for debugging) and into a
+	// buffer we can assert on. The mutex is shared with the logger so
+	// our reads don't race its writes.
+	stderr := io.MultiWriter(os.Stderr, &buffer)
+	// Custom hclog.Logger
+	clientLogger := hclog.New(&hclog.LoggerOptions{
+		Name:   "test-logger",
+		Level:  hclog.Trace,
+		Output: stderr,
+		Mutex:  mutex,
+	})
+
+	process := helperProcess("test-interface-logger-" + proto)
+	c := NewClient(&ClientConfig{
+		Cmd:              process,
+		HandshakeConfig:  testHandshake,
+		Plugins:          testPluginMap,
+		Logger:           clientLogger,
+		AllowedProtocols: []Protocol{ProtocolNetRPC, ProtocolGRPC},
+	})
+	defer c.Kill()
+
+	// Grab the RPC client
+	client, err := c.Client()
+	if err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	// Grab the impl
+	raw, err := client.Dispense("test")
+	if err != nil {
+		t.Fatalf("err should be nil, got %s", err)
+	}
+
+	impl, ok := raw.(testInterface)
+	if !ok {
+		t.Fatalf("bad: %#v", raw)
+	}
+
+	// expectLog logs the key "foo" with the given value through the
+	// plugin and asserts the forwarded log line contains want. The
+	// sleep gives the asynchronous log forwarding time to land.
+	expectLog := func(value interface{}, want string) {
+		// Discard everything else, and capture the output we care about
+		mutex.Lock()
+		buffer.Reset()
+		mutex.Unlock()
+		impl.PrintKV("foo", value)
+		time.Sleep(100 * time.Millisecond)
+		mutex.Lock()
+		line, err := buffer.ReadString('\n')
+		mutex.Unlock()
+		if err != nil {
+			t.Fatal(err)
+		}
+		if !strings.Contains(line, want) {
+			t.Fatalf("bad: %q", line)
+		}
+	}
+
+	// A string value and an integer value must both round-trip.
+	expectLog("bar", "foo=bar")
+	expectLog(12, "foo=12")
+
+	// Kill it
+	c.Kill()
+
+	// Test that it knows it is exited
+	if !c.Exited() {
+		t.Fatal("should say client has exited")
+	}
+}
--- /dev/null
+package plugin
+
+import (
+ "path/filepath"
+)
+
+// Discover discovers plugins that are in a given directory.
+//
+// The directory doesn't need to be absolute. For example, "." will work fine.
+//
+// This currently assumes any file matching the glob is a plugin.
+// In the future this may be smarter about checking that a file is
+// executable and so on.
+//
+// TODO: test
+func Discover(glob, dir string) ([]string, error) {
+ var err error
+
+ // Make the directory absolute if it isn't already
+ if !filepath.IsAbs(dir) {
+ dir, err = filepath.Abs(dir)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return filepath.Glob(filepath.Join(dir, glob))
+}
--- /dev/null
+# go-plugin Documentation
+
+This directory contains documentation and guides for `go-plugin` and how
+to integrate it into your projects. It is assumed that you know _what_
+go-plugin is and _why_ you would want to use it. If not, please see the
+[README](https://github.com/hashicorp/go-plugin/blob/master/README.md).
+
+## Table of Contents
+
+**[Writing Plugins Without Go](https://github.com/hashicorp/go-plugin/blob/master/docs/guide-plugin-write-non-go.md).**
+This shows how to write a plugin using a programming language other than
+Go.
--- /dev/null
+# Writing Plugins Without Go
+
+This guide explains how to write a go-plugin compatible plugin using
+a programming language other than Go. go-plugin supports plugins using
+[gRPC](http://www.grpc.io). This makes it relatively simple to write plugins
+using other languages!
+
+Minimal knowledge about gRPC is assumed. We recommend reading the
+[gRPC Go Tutorial](http://www.grpc.io/docs/tutorials/basic/go.html). This
+alone is enough gRPC knowledge to continue.
+
+This guide will implement the kv example in Python.
+Full source code for the examples present in this guide
+[is available in the examples/grpc folder](https://github.com/hashicorp/go-plugin/tree/master/examples/grpc).
+
+## 1. Implement the Service
+
+The first step is to implement the gRPC server for the protocol buffers
+service that your plugin defines. This is a standard gRPC server.
+For the KV service, the service looks like this:
+
+```proto
+service KV {
+ rpc Get(GetRequest) returns (GetResponse);
+ rpc Put(PutRequest) returns (Empty);
+}
+```
+
+We can implement that using Python as easily as:
+
+```python
+class KVServicer(kv_pb2_grpc.KVServicer):
+ """Implementation of KV service."""
+
+ def Get(self, request, context):
+ filename = "kv_"+request.key
+ with open(filename, 'r') as f:
+ result = kv_pb2.GetResponse()
+ result.value = f.read()
+ return result
+
+ def Put(self, request, context):
+ filename = "kv_"+request.key
+ value = "{0}\n\nWritten from plugin-python".format(request.value)
+ with open(filename, 'w') as f:
+ f.write(value)
+
+ return kv_pb2.Empty()
+
+```
+
+Great! With that, we have a fully functioning implementation of the service.
+You can test this using standard gRPC testing mechanisms.
+
+## 2. Serve the Service
+
+Next, we need to create a gRPC server and serve the service we just made.
+
+In Python:
+
+```python
+# Make the server
+server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+
+# Add our service
+kv_pb2_grpc.add_KVServicer_to_server(KVServicer(), server)
+
+# Listen on a port
+server.add_insecure_port(':1234')
+
+# Start
+server.start()
+```
+
+You can listen on any TCP address or Unix domain socket. go-plugin does
+assume that connections are reliable (local), so you should not serve
+your plugin across the network.
+
+## 3. Add the gRPC Health Checking Service
+
+go-plugin requires the
+[gRPC Health Checking Service](https://github.com/grpc/grpc/blob/master/doc/health-checking.md)
+to be registered on your server. You must register the status of "plugin" to be SERVING.
+
+The health checking service is used by go-plugin to determine if everything
+is healthy with the connection. If you don't implement this service, your
+process may be abruptly restarted and your plugins are likely to be unreliable.
+
+```
+health = HealthServicer()
+health.set("plugin", health_pb2.HealthCheckResponse.ServingStatus.Value('SERVING'))
+health_pb2_grpc.add_HealthServicer_to_server(health, server)
+```
+
+## 4. Output Handshake Information
+
+The final step is to output the handshake information to stdout. go-plugin
+reads a single line from stdout to determine how to connect to your plugin,
+what protocol it is using, etc.
+
+
+The structure is:
+
+```
+CORE-PROTOCOL-VERSION | APP-PROTOCOL-VERSION | NETWORK-TYPE | NETWORK-ADDR | PROTOCOL
+```
+
+Where:
+
+ * `CORE-PROTOCOL-VERSION` is the protocol version for go-plugin itself.
+ The current value is `1`. Please use this value. Any other value will
+ cause your plugin to not load.
+
+ * `APP-PROTOCOL-VERSION` is the protocol version for the application data.
+ This is determined by the application. You must reference the documentation
+ for your application to determine the desired value.
+
+ * `NETWORK-TYPE` and `NETWORK-ADDR` are the networking information for
+ connecting to this plugin. The type must be "unix" or "tcp". The address
+ is a path to the Unix socket for "unix" and an IP address for "tcp".
+
+ * `PROTOCOL` is the named protocol that the connection will use. If this
+ is omitted (older versions), this is "netrpc" for Go net/rpc. This can
+ also be "grpc". This is the protocol that the plugin wants to speak to
+ the host process with.
+
+For our example that is:
+
+```
+1|1|tcp|127.0.0.1:1234|grpc
+```
+
+The only element you'll have to be careful about is the second one (the
+`APP-PROTOCOL-VERSION`). This will depend on the application you're
+building a plugin for. Please reference their documentation for more
+information.
+
+## 5. Done!
+
+And we're done!
+
+Configure the host application (the application you're writing a plugin
+for) to execute your Python application. Configuring plugins is specific
+to the host application.
+
+For our example, we used an environment variable, and it looks like this:
+
+```sh
+$ export KV_PLUGIN="python plugin.py"
+```
--- /dev/null
+# go-plugin Internals
+
+This section discusses the internals of how go-plugin works.
+
+go-plugin operates by either _serving_ a plugin or being a _client_
+connecting to a remote plugin. The "client" is the host process or the
+process that itself uses plugins. The "server" is the plugin process.
+
+For a server:
+
+ 1. Output handshake to stdout
+ 2. Wait for connection on control address
+ 3. Serve plugins over control address
+
+For a client:
+
+ 1. Launch a plugin binary
+ 2. Read and verify handshake from plugin stdout
+ 3. Connect to plugin control address using desired protocol
+ 4. Dispense plugins using control connection
+
+## Handshake
+
+The handshake is the initial communication between a plugin and a host
+process to determine how the host process can connect and communicate to
+the plugin. This handshake is done over the plugin process's stdout.
+
+The `go-plugin` library itself handles the handshake when using the
+`Server` to serve a plugin. **You do not need to understand the internals
+of the handshake,** unless you're building a go-plugin compatible plugin
+in another language.
+
+The handshake is a single line of data terminated with a newline character
+`\n`. It looks like the following:
+
+```
+1|3|unix|/path/to/socket|grpc
+```
+
+The structure is:
+
+```
+CORE-PROTOCOL-VERSION | APP-PROTOCOL-VERSION | NETWORK-TYPE | NETWORK-ADDR | PROTOCOL
+```
+
+Where:
+
+ * `CORE-PROTOCOL-VERSION` is the protocol version for go-plugin itself.
+ The current value is `1`. Please use this value. Any other value will
+ cause your plugin to not load.
+
+ * `APP-PROTOCOL-VERSION` is the protocol version for the application data.
+ This is determined by the application. You must reference the documentation
+ for your application to determine the desired value.
+
+ * `NETWORK-TYPE` and `NETWORK-ADDR` are the networking information for
+ connecting to this plugin. The type must be "unix" or "tcp". The address
+ is a path to the Unix socket for "unix" and an IP address for "tcp".
+
+ * `PROTOCOL` is the named protocol that the connection will use. If this
+ is omitted (older versions), this is "netrpc" for Go net/rpc. This can
+ also be "grpc". This is the protocol that the plugin wants to speak to
+ the host process with.
--- /dev/null
+package plugin
+
+// BasicError wraps error values so they can be messaged across RPC
+// channels. Since "error" is an interface, the concrete value behind it
+// often cannot be gob-encoded; BasicError carries only the message
+// string, which is always encodable, and itself satisfies error.
+type BasicError struct {
+	Message string
+}
+
+// NewBasicError is used to create a BasicError.
+//
+// err is allowed to be nil.
+func NewBasicError(err error) *BasicError {
+	if err == nil {
+		return nil
+	}
+
+	return &BasicError{Message: err.Error()}
+}
+
+// Error implements the error interface by returning the wrapped message.
+func (e *BasicError) Error() string {
+	return e.Message
+}
--- /dev/null
+package plugin
+
+import (
+ "errors"
+ "testing"
+)
+
+// TestBasicError_ImplementsError is a compile-time assertion that
+// *BasicError satisfies the error interface.
+func TestBasicError_ImplementsError(t *testing.T) {
+	var _ error = (*BasicError)(nil)
+}
+
+// TestBasicError_MatchesMessage checks that the wrapped message
+// round-trips through BasicError unchanged.
+func TestBasicError_MatchesMessage(t *testing.T) {
+	original := errors.New("foo")
+	wrapped := NewBasicError(original)
+
+	if got, want := wrapped.Error(), original.Error(); got != want {
+		t.Fatalf("bad: %#v", got)
+	}
+}
+
+// TestNewBasicError_nil checks that wrapping a nil error yields nil.
+func TestNewBasicError_nil(t *testing.T) {
+	if got := NewBasicError(nil); got != nil {
+		t.Fatalf("bad: %#v", got)
+	}
+}
--- /dev/null
+# Ignore binaries
+plugin/greeter
+basic
\ No newline at end of file
--- /dev/null
+Plugin Example
+--------------
+
+Compile the plugin itself via:
+
+ go build -o ./plugin/greeter ./plugin/greeter_impl.go
+
+Compile this driver via:
+
+ go build -o basic .
+
+You can then launch the plugin sample via:
+
+ ./basic
--- /dev/null
+package example
+
+import (
+ "net/rpc"
+
+ "github.com/hashicorp/go-plugin"
+)
+
+// Greeter is the interface that we're exposing as a plugin.
+type Greeter interface {
+	// Greet returns a greeting message.
+	Greet() string
+}
+
+// GreeterRPC is an implementation of Greeter that forwards calls over a
+// net/rpc connection to the plugin process.
+type GreeterRPC struct{ client *rpc.Client }
+
+// Greet asks the remote plugin for a greeting.
+func (g *GreeterRPC) Greet() string {
+	var resp string
+	if err := g.client.Call("Plugin.Greet", new(interface{}), &resp); err != nil {
+		// You usually want your interfaces to return errors. If they don't,
+		// there isn't much other choice here.
+		panic(err)
+	}
+	return resp
+}
+
+// GreeterRPCServer is the RPC server that GreeterRPC talks to,
+// conforming to the requirements of net/rpc.
+type GreeterRPCServer struct {
+	// Impl is the real, in-process implementation being exposed.
+	Impl Greeter
+}
+
+// Greet forwards the RPC call to the real implementation and writes the
+// result into resp.
+func (s *GreeterRPCServer) Greet(args interface{}, resp *string) error {
+	*resp = s.Impl.Greet()
+	return nil
+}
+
+// This is the implementation of plugin.Plugin so we can serve/consume this
+//
+// This has two methods: Server must return an RPC server for this plugin
+// type. We construct a GreeterRPCServer for this.
+//
+// Client must return an implementation of our interface that communicates
+// over an RPC client. We return GreeterRPC for this.
+//
+// Ignore MuxBroker. That is used to create more multiplexed streams on our
+// plugin connection and is a more advanced use case.
+type GreeterPlugin struct {
+	// Impl Injection
+	Impl Greeter
+}
+
+// Server returns the net/rpc server wrapping the injected implementation.
+func (p *GreeterPlugin) Server(*plugin.MuxBroker) (interface{}, error) {
+	return &GreeterRPCServer{Impl: p.Impl}, nil
+}
+
+// Client returns the RPC-backed Greeter implementation. Note the pointer
+// receiver: it keeps the method set consistent with Server, which also
+// uses a pointer receiver (the plugin is always used as *GreeterPlugin).
+func (p *GreeterPlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
+	return &GreeterRPC{client: c}, nil
+}
--- /dev/null
+package main
+
+import (
+ "fmt"
+ "log"
+ "os"
+ "os/exec"
+
+ hclog "github.com/hashicorp/go-hclog"
+ "github.com/hashicorp/go-plugin"
+ "github.com/hashicorp/go-plugin/examples/basic/commons"
+)
+
+func main() {
+	// Create an hclog.Logger so plugin output is visible and leveled.
+	logger := hclog.New(&hclog.LoggerOptions{
+		Name:   "plugin",
+		Output: os.Stdout,
+		Level:  hclog.Debug,
+	})
+
+	// We're a host! Start by launching the plugin process.
+	client := plugin.NewClient(&plugin.ClientConfig{
+		HandshakeConfig: handshakeConfig,
+		Plugins:         pluginMap,
+		Cmd:             exec.Command("./plugin/greeter"),
+		Logger:          logger,
+	})
+	defer client.Kill()
+
+	// The work is in run so that the plugin process is always cleaned
+	// up: log.Fatal calls os.Exit, which skips deferred calls, so we
+	// kill the client explicitly before exiting on error.
+	if err := run(client); err != nil {
+		client.Kill()
+		log.Fatal(err)
+	}
+}
+
+// run dispenses the greeter plugin from the launched client and prints
+// its greeting.
+func run(client *plugin.Client) error {
+	// Connect via RPC
+	rpcClient, err := client.Client()
+	if err != nil {
+		return err
+	}
+
+	// Request the plugin
+	raw, err := rpcClient.Dispense("greeter")
+	if err != nil {
+		return err
+	}
+
+	// We should have a Greeter now! This feels like a normal interface
+	// implementation but is in fact over an RPC connection.
+	greeter := raw.(example.Greeter)
+	fmt.Println(greeter.Greet())
+	return nil
+}
+
+// handshakeConfigs are used to just do a basic handshake between
+// a plugin and host. If the handshake fails, a user friendly error is shown.
+// This prevents users from executing bad plugins or executing a plugin
+// directory. It is a UX feature, not a security feature.
+var handshakeConfig = plugin.HandshakeConfig{
+	// ProtocolVersion must match the version the plugin binary serves.
+	ProtocolVersion: 1,
+	// The cookie key/value pair must be identical on both sides.
+	MagicCookieKey:   "BASIC_PLUGIN",
+	MagicCookieValue: "hello",
+}
+
+// pluginMap is the map of plugins we can dispense.
+// The key "greeter" is the name passed to Dispense in main.
+var pluginMap = map[string]plugin.Plugin{
+	"greeter": &example.GreeterPlugin{},
+}
--- /dev/null
+package main
+
+import (
+ "os"
+
+ "github.com/hashicorp/go-hclog"
+ "github.com/hashicorp/go-plugin"
+ "github.com/hashicorp/go-plugin/examples/basic/commons"
+)
+
+// GreeterHello is a real, in-process implementation of Greeter served by
+// this plugin binary.
+type GreeterHello struct {
+	// logger demonstrates plugin-side logging being forwarded to the host.
+	logger hclog.Logger
+}
+
+// Greet logs a debug message and returns a fixed greeting.
+func (g *GreeterHello) Greet() string {
+	g.logger.Debug("message from GreeterHello.Greet")
+	return "Hello!"
+}
+
+// handshakeConfigs are used to just do a basic handshake between
+// a plugin and host. If the handshake fails, a user friendly error is shown.
+// This prevents users from executing bad plugins or executing a plugin
+// directory. It is a UX feature, not a security feature.
+var handshakeConfig = plugin.HandshakeConfig{
+	// These values must match the host's handshake configuration exactly.
+	ProtocolVersion:  1,
+	MagicCookieKey:   "BASIC_PLUGIN",
+	MagicCookieValue: "hello",
+}
+
+func main() {
+ logger := hclog.New(&hclog.LoggerOptions{
+ Level: hclog.Trace,
+ Output: os.Stderr,
+ JSONFormat: true,
+ })
+
+ greeter := &GreeterHello{
+ logger: logger,
+ }
+ // pluginMap is the map of plugins we can dispense.
+ var pluginMap = map[string]plugin.Plugin{
+ "greeter": &example.GreeterPlugin{Impl: greeter},
+ }
+
+ logger.Debug("message from plugin", "foo", "bar")
+
+ plugin.Serve(&plugin.ServeConfig{
+ HandshakeConfig: handshakeConfig,
+ Plugins: pluginMap,
+ })
+}
--- /dev/null
+# Counter Example
+
+This example builds a simple key/counter store CLI where the mechanism
+for storing and retrieving keys is pluggable. However, in this example we don't
+trust the plugin to do the summation work. We use bi-directional plugins to
+call back into the main process to do the sum of two numbers. To build this example:
+
+```sh
+# This builds the main CLI
+$ go build -o counter
+
+# This builds the plugin written in Go
+$ go build -o counter-go-grpc ./plugin-go-grpc
+
+# This tells the Counter binary to use the "counter-go-grpc" binary
+$ export COUNTER_PLUGIN="./counter-go-grpc"
+
+# Read and write
+$ ./counter put hello 1
+$ ./counter put hello 1
+
+$ ./counter get hello
+2
+```
+
+### Plugin: plugin-go-grpc
+
+This plugin uses gRPC to serve a plugin that is written in Go:
+
+```
+# This builds the plugin written in Go
+$ go build -o counter-go-grpc ./plugin-go-grpc
+
+# This tells the Counter binary to use the "counter-go-grpc" binary
+$ export COUNTER_PLUGIN="./counter-go-grpc"
+```
+
+## Updating the Protocol
+
+If you update the protocol buffers file, you can regenerate the file
+using the following command from this directory. You do not need to run
+this if you're just trying the example.
+
+For Go:
+
+```sh
+$ protoc -I proto/ proto/kv.proto --go_out=plugins=grpc:proto/
+```
--- /dev/null
+package main
+
+import (
+ "fmt"
+ "io/ioutil"
+ "log"
+ "os"
+ "os/exec"
+ "strconv"
+
+ "github.com/hashicorp/go-plugin"
+ "github.com/hashicorp/go-plugin/examples/bidirectional/shared"
+)
+
+// addHelper is the host-side implementation of the AddHelper callback:
+// the plugin calls back into this process to perform addition.
+type addHelper struct{}
+
+// Sum returns x+y; it never fails.
+func (*addHelper) Sum(x, y int64) (int64, error) {
+	return x + y, nil
+}
+
+func main() {
+	// All work happens in run so deferred cleanup executes: calling
+	// os.Exit directly here would skip deferred calls and leave the
+	// plugin subprocess running.
+	os.Exit(run())
+}
+
+// run launches the counter plugin, dispenses it, and executes the
+// requested "get" or "put" subcommand. It returns the process exit code.
+func run() int {
+	// We don't want to see the plugin logs.
+	log.SetOutput(ioutil.Discard)
+
+	// We're a host. Start by launching the plugin process.
+	client := plugin.NewClient(&plugin.ClientConfig{
+		HandshakeConfig: shared.Handshake,
+		Plugins:         shared.PluginMap,
+		Cmd:             exec.Command("sh", "-c", os.Getenv("COUNTER_PLUGIN")),
+		AllowedProtocols: []plugin.Protocol{
+			plugin.ProtocolNetRPC, plugin.ProtocolGRPC},
+	})
+	defer client.Kill()
+
+	// Connect via RPC
+	rpcClient, err := client.Client()
+	if err != nil {
+		fmt.Println("Error:", err.Error())
+		return 1
+	}
+
+	// Request the plugin
+	raw, err := rpcClient.Dispense("counter")
+	if err != nil {
+		fmt.Println("Error:", err.Error())
+		return 1
+	}
+
+	// We should have a Counter store now! This feels like a normal interface
+	// implementation but is in fact over an RPC connection.
+	counter := raw.(shared.Counter)
+
+	// Guard argument access: indexing os.Args blindly would panic when
+	// the subcommand or its arguments are missing.
+	args := os.Args[1:]
+	if len(args) == 0 {
+		fmt.Println("Please only use 'get' or 'put'")
+		return 1
+	}
+
+	switch args[0] {
+	case "get":
+		if len(args) < 2 {
+			fmt.Println("Please only use 'get' or 'put'")
+			return 1
+		}
+		result, err := counter.Get(args[1])
+		if err != nil {
+			fmt.Println("Error:", err.Error())
+			return 1
+		}
+		fmt.Println(result)
+
+	case "put":
+		if len(args) < 3 {
+			fmt.Println("Please only use 'get' or 'put'")
+			return 1
+		}
+		i, err := strconv.Atoi(args[2])
+		if err != nil {
+			fmt.Println("Error:", err.Error())
+			return 1
+		}
+		if err := counter.Put(args[1], int64(i), &addHelper{}); err != nil {
+			fmt.Println("Error:", err.Error())
+			return 1
+		}
+
+	default:
+		fmt.Println("Please only use 'get' or 'put'")
+		return 1
+	}
+
+	return 0
+}
--- /dev/null
+package main
+
+import (
+	"encoding/json"
+	"io/ioutil"
+	"os"
+
+	"github.com/hashicorp/go-plugin"
+	"github.com/hashicorp/go-plugin/examples/bidirectional/shared"
+)
+
+// Counter is a real implementation of the counter store that persists
+// each counter to a local file named "kv_<key>" holding a small JSON
+// document.
+type Counter struct {
+}
+
+// data is the on-disk JSON representation of a counter value.
+type data struct {
+	Value int64
+}
+
+// Put adds value to the existing counter for key and writes the result
+// back to disk. The summation itself is delegated back to the host
+// process through the AddHelper callback.
+func (k *Counter) Put(key string, value int64, a shared.AddHelper) error {
+	// A missing file just means the counter starts at zero, but any
+	// other read/parse failure is a real error and must not silently
+	// reset the counter (the old behavior on corrupt data).
+	v, err := k.Get(key)
+	if err != nil && !os.IsNotExist(err) {
+		return err
+	}
+
+	r, err := a.Sum(v, value)
+	if err != nil {
+		return err
+	}
+
+	buf, err := json.Marshal(&data{r})
+	if err != nil {
+		return err
+	}
+
+	return ioutil.WriteFile("kv_"+key, buf, 0644)
+}
+
+// Get reads and decodes the current counter value for key.
+func (k *Counter) Get(key string) (int64, error) {
+	dataRaw, err := ioutil.ReadFile("kv_" + key)
+	if err != nil {
+		return 0, err
+	}
+
+	d := &data{}
+	if err := json.Unmarshal(dataRaw, d); err != nil {
+		return 0, err
+	}
+
+	return d.Value, nil
+}
+
+// main serves the Counter implementation as a go-plugin, with gRPC
+// enabled for hosts that allow it.
+func main() {
+	plugin.Serve(&plugin.ServeConfig{
+		HandshakeConfig: shared.Handshake,
+		Plugins: map[string]plugin.Plugin{
+			"counter": &shared.CounterPlugin{Impl: &Counter{}},
+		},
+
+		// A non-nil value here enables gRPC serving for this plugin...
+		GRPCServer: plugin.DefaultGRPCServer,
+	})
+}
--- /dev/null
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: kv.proto
+
+/*
+Package proto is a generated protocol buffer package.
+
+It is generated from these files:
+ kv.proto
+
+It has these top-level messages:
+ GetRequest
+ GetResponse
+ PutRequest
+ Empty
+ SumRequest
+ SumResponse
+*/
+package proto
+
+import proto1 "github.com/golang/protobuf/proto"
+import fmt "fmt"
+import math "math"
+
+import (
+ context "golang.org/x/net/context"
+ grpc "google.golang.org/grpc"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto1.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto1.ProtoPackageIsVersion2 // please upgrade the proto package
+
+type GetRequest struct {
+ Key string `protobuf:"bytes,1,opt,name=key" json:"key,omitempty"`
+}
+
+func (m *GetRequest) Reset() { *m = GetRequest{} }
+func (m *GetRequest) String() string { return proto1.CompactTextString(m) }
+func (*GetRequest) ProtoMessage() {}
+func (*GetRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
+
+func (m *GetRequest) GetKey() string {
+ if m != nil {
+ return m.Key
+ }
+ return ""
+}
+
+type GetResponse struct {
+ Value int64 `protobuf:"varint,1,opt,name=value" json:"value,omitempty"`
+}
+
+func (m *GetResponse) Reset() { *m = GetResponse{} }
+func (m *GetResponse) String() string { return proto1.CompactTextString(m) }
+func (*GetResponse) ProtoMessage() {}
+func (*GetResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
+
+func (m *GetResponse) GetValue() int64 {
+ if m != nil {
+ return m.Value
+ }
+ return 0
+}
+
+type PutRequest struct {
+ AddServer uint32 `protobuf:"varint,1,opt,name=add_server,json=addServer" json:"add_server,omitempty"`
+ Key string `protobuf:"bytes,2,opt,name=key" json:"key,omitempty"`
+ Value int64 `protobuf:"varint,3,opt,name=value" json:"value,omitempty"`
+}
+
+func (m *PutRequest) Reset() { *m = PutRequest{} }
+func (m *PutRequest) String() string { return proto1.CompactTextString(m) }
+func (*PutRequest) ProtoMessage() {}
+func (*PutRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
+
+func (m *PutRequest) GetAddServer() uint32 {
+ if m != nil {
+ return m.AddServer
+ }
+ return 0
+}
+
+func (m *PutRequest) GetKey() string {
+ if m != nil {
+ return m.Key
+ }
+ return ""
+}
+
+func (m *PutRequest) GetValue() int64 {
+ if m != nil {
+ return m.Value
+ }
+ return 0
+}
+
+type Empty struct {
+}
+
+func (m *Empty) Reset() { *m = Empty{} }
+func (m *Empty) String() string { return proto1.CompactTextString(m) }
+func (*Empty) ProtoMessage() {}
+func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
+
+type SumRequest struct {
+ A int64 `protobuf:"varint,1,opt,name=a" json:"a,omitempty"`
+ B int64 `protobuf:"varint,2,opt,name=b" json:"b,omitempty"`
+}
+
+func (m *SumRequest) Reset() { *m = SumRequest{} }
+func (m *SumRequest) String() string { return proto1.CompactTextString(m) }
+func (*SumRequest) ProtoMessage() {}
+func (*SumRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
+
+func (m *SumRequest) GetA() int64 {
+ if m != nil {
+ return m.A
+ }
+ return 0
+}
+
+func (m *SumRequest) GetB() int64 {
+ if m != nil {
+ return m.B
+ }
+ return 0
+}
+
+type SumResponse struct {
+ R int64 `protobuf:"varint,1,opt,name=r" json:"r,omitempty"`
+}
+
+func (m *SumResponse) Reset() { *m = SumResponse{} }
+func (m *SumResponse) String() string { return proto1.CompactTextString(m) }
+func (*SumResponse) ProtoMessage() {}
+func (*SumResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
+
+func (m *SumResponse) GetR() int64 {
+ if m != nil {
+ return m.R
+ }
+ return 0
+}
+
+func init() {
+ proto1.RegisterType((*GetRequest)(nil), "proto.GetRequest")
+ proto1.RegisterType((*GetResponse)(nil), "proto.GetResponse")
+ proto1.RegisterType((*PutRequest)(nil), "proto.PutRequest")
+ proto1.RegisterType((*Empty)(nil), "proto.Empty")
+ proto1.RegisterType((*SumRequest)(nil), "proto.SumRequest")
+ proto1.RegisterType((*SumResponse)(nil), "proto.SumResponse")
+}
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ context.Context
+var _ grpc.ClientConn
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the grpc package it is being compiled against.
+const _ = grpc.SupportPackageIsVersion4
+
+// Client API for Counter service
+
+type CounterClient interface {
+ Get(ctx context.Context, in *GetRequest, opts ...grpc.CallOption) (*GetResponse, error)
+ Put(ctx context.Context, in *PutRequest, opts ...grpc.CallOption) (*Empty, error)
+}
+
+type counterClient struct {
+ cc *grpc.ClientConn
+}
+
+func NewCounterClient(cc *grpc.ClientConn) CounterClient {
+ return &counterClient{cc}
+}
+
+func (c *counterClient) Get(ctx context.Context, in *GetRequest, opts ...grpc.CallOption) (*GetResponse, error) {
+ out := new(GetResponse)
+ err := grpc.Invoke(ctx, "/proto.Counter/Get", in, out, c.cc, opts...)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+func (c *counterClient) Put(ctx context.Context, in *PutRequest, opts ...grpc.CallOption) (*Empty, error) {
+ out := new(Empty)
+ err := grpc.Invoke(ctx, "/proto.Counter/Put", in, out, c.cc, opts...)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+// Server API for Counter service
+
+type CounterServer interface {
+ Get(context.Context, *GetRequest) (*GetResponse, error)
+ Put(context.Context, *PutRequest) (*Empty, error)
+}
+
+func RegisterCounterServer(s *grpc.Server, srv CounterServer) {
+ s.RegisterService(&_Counter_serviceDesc, srv)
+}
+
+func _Counter_Get_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+ in := new(GetRequest)
+ if err := dec(in); err != nil {
+ return nil, err
+ }
+ if interceptor == nil {
+ return srv.(CounterServer).Get(ctx, in)
+ }
+ info := &grpc.UnaryServerInfo{
+ Server: srv,
+ FullMethod: "/proto.Counter/Get",
+ }
+ handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+ return srv.(CounterServer).Get(ctx, req.(*GetRequest))
+ }
+ return interceptor(ctx, in, info, handler)
+}
+
+func _Counter_Put_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+ in := new(PutRequest)
+ if err := dec(in); err != nil {
+ return nil, err
+ }
+ if interceptor == nil {
+ return srv.(CounterServer).Put(ctx, in)
+ }
+ info := &grpc.UnaryServerInfo{
+ Server: srv,
+ FullMethod: "/proto.Counter/Put",
+ }
+ handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+ return srv.(CounterServer).Put(ctx, req.(*PutRequest))
+ }
+ return interceptor(ctx, in, info, handler)
+}
+
+var _Counter_serviceDesc = grpc.ServiceDesc{
+ ServiceName: "proto.Counter",
+ HandlerType: (*CounterServer)(nil),
+ Methods: []grpc.MethodDesc{
+ {
+ MethodName: "Get",
+ Handler: _Counter_Get_Handler,
+ },
+ {
+ MethodName: "Put",
+ Handler: _Counter_Put_Handler,
+ },
+ },
+ Streams: []grpc.StreamDesc{},
+ Metadata: "kv.proto",
+}
+
+// Client API for AddHelper service
+
+type AddHelperClient interface {
+ Sum(ctx context.Context, in *SumRequest, opts ...grpc.CallOption) (*SumResponse, error)
+}
+
+type addHelperClient struct {
+ cc *grpc.ClientConn
+}
+
+func NewAddHelperClient(cc *grpc.ClientConn) AddHelperClient {
+ return &addHelperClient{cc}
+}
+
+func (c *addHelperClient) Sum(ctx context.Context, in *SumRequest, opts ...grpc.CallOption) (*SumResponse, error) {
+ out := new(SumResponse)
+ err := grpc.Invoke(ctx, "/proto.AddHelper/Sum", in, out, c.cc, opts...)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+// Server API for AddHelper service
+
+type AddHelperServer interface {
+ Sum(context.Context, *SumRequest) (*SumResponse, error)
+}
+
+func RegisterAddHelperServer(s *grpc.Server, srv AddHelperServer) {
+ s.RegisterService(&_AddHelper_serviceDesc, srv)
+}
+
+func _AddHelper_Sum_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+ in := new(SumRequest)
+ if err := dec(in); err != nil {
+ return nil, err
+ }
+ if interceptor == nil {
+ return srv.(AddHelperServer).Sum(ctx, in)
+ }
+ info := &grpc.UnaryServerInfo{
+ Server: srv,
+ FullMethod: "/proto.AddHelper/Sum",
+ }
+ handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+ return srv.(AddHelperServer).Sum(ctx, req.(*SumRequest))
+ }
+ return interceptor(ctx, in, info, handler)
+}
+
+var _AddHelper_serviceDesc = grpc.ServiceDesc{
+ ServiceName: "proto.AddHelper",
+ HandlerType: (*AddHelperServer)(nil),
+ Methods: []grpc.MethodDesc{
+ {
+ MethodName: "Sum",
+ Handler: _AddHelper_Sum_Handler,
+ },
+ },
+ Streams: []grpc.StreamDesc{},
+ Metadata: "kv.proto",
+}
+
+func init() { proto1.RegisterFile("kv.proto", fileDescriptor0) }
+
+var fileDescriptor0 = []byte{
+ // 253 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x4c, 0x90, 0x4f, 0x4b, 0x03, 0x31,
+ 0x10, 0xc5, 0x89, 0x61, 0xad, 0xfb, 0xba, 0x82, 0x06, 0x0f, 0x52, 0x51, 0x24, 0x82, 0xf4, 0x20,
+ 0x3d, 0xd4, 0x93, 0x47, 0x11, 0xa9, 0xc7, 0x92, 0xfd, 0x00, 0x25, 0x4b, 0xe6, 0xd4, 0x6e, 0x77,
+ 0xcd, 0x26, 0x0b, 0xfd, 0xf6, 0xa5, 0xd9, 0x3f, 0xd9, 0x53, 0x32, 0x2f, 0x2f, 0xbf, 0x37, 0x33,
+ 0xb8, 0xd9, 0xb7, 0xab, 0xda, 0x56, 0xae, 0x12, 0x49, 0x38, 0xe4, 0x0b, 0xb0, 0x21, 0xa7, 0xe8,
+ 0xdf, 0x53, 0xe3, 0xc4, 0x1d, 0xf8, 0x9e, 0x4e, 0x8f, 0xec, 0x95, 0x2d, 0x53, 0x75, 0xb9, 0xca,
+ 0x37, 0xcc, 0xc3, 0x7b, 0x53, 0x57, 0xc7, 0x86, 0xc4, 0x03, 0x92, 0x56, 0x1f, 0x3c, 0x05, 0x0b,
+ 0x57, 0x5d, 0x21, 0x73, 0x60, 0xeb, 0x47, 0xc8, 0x33, 0xa0, 0x8d, 0xd9, 0x35, 0x64, 0x5b, 0xb2,
+ 0xc1, 0x78, 0xab, 0x52, 0x6d, 0x4c, 0x1e, 0x84, 0x21, 0xe3, 0x6a, 0xcc, 0x88, 0x50, 0x3e, 0x85,
+ 0xce, 0x90, 0xfc, 0x96, 0xb5, 0x3b, 0xc9, 0x25, 0x90, 0xfb, 0x72, 0xa0, 0x67, 0x60, 0xba, 0x4f,
+ 0x67, 0xfa, 0x52, 0x15, 0x01, 0xc5, 0x15, 0x2b, 0xe4, 0x13, 0xe6, 0xc1, 0xd9, 0x37, 0x9b, 0x81,
+ 0xd9, 0xc1, 0x6a, 0xd7, 0x3b, 0xcc, 0x7e, 0x2a, 0x7f, 0x74, 0x64, 0xc5, 0x07, 0xf8, 0x86, 0x9c,
+ 0xb8, 0xef, 0x56, 0xb1, 0x8a, 0x0b, 0x58, 0x88, 0xa9, 0xd4, 0x63, 0xde, 0xc1, 0xb7, 0x3e, 0xba,
+ 0xe3, 0xa4, 0x8b, 0xac, 0x97, 0x42, 0x9f, 0xeb, 0x2f, 0xa4, 0xdf, 0xc6, 0xfc, 0xd1, 0xa1, 0xee,
+ 0x22, 0x72, 0x5f, 0x8e, 0x9f, 0xe2, 0x00, 0x63, 0xc4, 0xa4, 0xd3, 0xe2, 0x3a, 0x48, 0x9f, 0xe7,
+ 0x00, 0x00, 0x00, 0xff, 0xff, 0x40, 0xa3, 0x85, 0x07, 0x9f, 0x01, 0x00, 0x00,
+}
--- /dev/null
syntax = "proto3";
package proto;

// GetRequest names the counter whose value is requested.
message GetRequest {
    string key = 1;
}

// GetResponse carries the current value of a counter.
message GetResponse {
    int64 value = 1;
}

// PutRequest adds a value to a counter. add_server is the broker ID of
// the host's AddHelper service that the plugin dials back to.
message PutRequest {
    uint32 add_server = 1;
    string key = 2;
    int64 value = 3;
}

// Empty is the response for RPCs with no meaningful result.
message Empty {}

// SumRequest carries the two operands to add.
message SumRequest {
    int64 a = 1;
    int64 b = 2;
}

// SumResponse carries the computed sum.
message SumResponse {
    int64 r = 1;
}

// Counter is the service the plugin exposes to the host.
service Counter {
    rpc Get(GetRequest) returns (GetResponse);
    rpc Put(PutRequest) returns (Empty);
}

// AddHelper is the service the host exposes back to the plugin
// (bidirectional communication).
service AddHelper {
    rpc Sum(SumRequest) returns (SumResponse);
}
--- /dev/null
+package shared
+
+import (
+ hclog "github.com/hashicorp/go-hclog"
+ plugin "github.com/hashicorp/go-plugin"
+ "github.com/hashicorp/go-plugin/examples/bidirectional/proto"
+ "golang.org/x/net/context"
+ "google.golang.org/grpc"
+)
+
// GRPCClient is an implementation of Counter that talks over gRPC. It is
// handed to the host by CounterPlugin.GRPCClient and proxies every call
// to the plugin process. (The original comment said "KV" — a copy-paste
// from the kv example.)
type GRPCClient struct {
	broker *plugin.GRPCBroker
	client proto.CounterClient
}
+
+func (m *GRPCClient) Put(key string, value int64, a AddHelper) error {
+ addHelperServer := &GRPCAddHelperServer{Impl: a}
+
+ var s *grpc.Server
+ serverFunc := func(opts []grpc.ServerOption) *grpc.Server {
+ s = grpc.NewServer(opts...)
+ proto.RegisterAddHelperServer(s, addHelperServer)
+
+ return s
+ }
+
+ brokerID := m.broker.NextId()
+ go m.broker.AcceptAndServe(brokerID, serverFunc)
+
+ _, err := m.client.Put(context.Background(), &proto.PutRequest{
+ AddServer: brokerID,
+ Key: key,
+ Value: value,
+ })
+
+ s.Stop()
+ return err
+}
+
+func (m *GRPCClient) Get(key string) (int64, error) {
+ resp, err := m.client.Get(context.Background(), &proto.GetRequest{
+ Key: key,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ return resp.Value, nil
+}
+
// GRPCServer is the plugin-side gRPC server that GRPCClient talks to.
// It adapts incoming RPCs onto the concrete Counter implementation.
type GRPCServer struct {
	// This is the real implementation
	Impl Counter

	// broker is used to dial back to services the host exposes
	// (see Put, which connects to the host's AddHelper server).
	broker *plugin.GRPCBroker
}
+
+func (m *GRPCServer) Put(ctx context.Context, req *proto.PutRequest) (*proto.Empty, error) {
+ conn, err := m.broker.Dial(req.AddServer)
+ if err != nil {
+ return nil, err
+ }
+ defer conn.Close()
+
+ a := &GRPCAddHelperClient{proto.NewAddHelperClient(conn)}
+ return &proto.Empty{}, m.Impl.Put(req.Key, req.Value, a)
+}
+
+func (m *GRPCServer) Get(ctx context.Context, req *proto.GetRequest) (*proto.GetResponse, error) {
+ v, err := m.Impl.Get(req.Key)
+ return &proto.GetResponse{Value: v}, err
+}
+
+// GRPCClient is an implementation of KV that talks over RPC.
+type GRPCAddHelperClient struct{ client proto.AddHelperClient }
+
+func (m *GRPCAddHelperClient) Sum(a, b int64) (int64, error) {
+ resp, err := m.client.Sum(context.Background(), &proto.SumRequest{
+ A: a,
+ B: b,
+ })
+ if err != nil {
+ hclog.Default().Info("add.Sum", "client", "start", "err", err)
+ return 0, err
+ }
+ return resp.R, err
+}
+
+// Here is the gRPC server that GRPCClient talks to.
+type GRPCAddHelperServer struct {
+ // This is the real implementation
+ Impl AddHelper
+}
+
+func (m *GRPCAddHelperServer) Sum(ctx context.Context, req *proto.SumRequest) (resp *proto.SumResponse, err error) {
+ r, err := m.Impl.Sum(req.A, req.B)
+ if err != nil {
+ return nil, err
+ }
+ return &proto.SumResponse{R: r}, err
+}
--- /dev/null
+// Package shared contains shared data between the host and plugins.
+package shared
+
+import (
+ "golang.org/x/net/context"
+ "google.golang.org/grpc"
+
+ "github.com/hashicorp/go-plugin"
+ "github.com/hashicorp/go-plugin/examples/bidirectional/proto"
+)
+
// Handshake is a common handshake that is shared by plugin and host.
// The magic cookie key/value must match on both sides or the plugin
// handshake is rejected before any RPC takes place.
var Handshake = plugin.HandshakeConfig{
	ProtocolVersion:  1,
	MagicCookieKey:   "BASIC_PLUGIN",
	MagicCookieValue: "hello",
}
+
// PluginMap is the map of plugins we can dispense. The key ("counter")
// is the name the host passes to Dispense.
var PluginMap = map[string]plugin.Plugin{
	"counter": &CounterPlugin{},
}
+
// AddHelper is the callback interface the host implements so the plugin
// can delegate additions back to it (the "bidirectional" part of this
// example).
type AddHelper interface {
	Sum(int64, int64) (int64, error)
}

// Counter is the interface that we're exposing as a plugin. (The
// original comment called it "KV" — copied from the kv example.)
type Counter interface {
	Put(key string, value int64, a AddHelper) error
	Get(key string) (int64, error)
}
+
// CounterPlugin is the implementation of plugin.Plugin so we can
// serve/consume this. We also implement GRPCPlugin so that this plugin
// can be served over gRPC.
type CounterPlugin struct {
	// Embedding NetRPCUnsupportedPlugin means only the gRPC transport
	// is offered by this plugin.
	plugin.NetRPCUnsupportedPlugin
	// Concrete implementation, written in Go. This is only used for plugins
	// that are written in Go.
	Impl Counter
}
+
+func (p *CounterPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error {
+ proto.RegisterCounterServer(s, &GRPCServer{
+ Impl: p.Impl,
+ broker: broker,
+ })
+ return nil
+}
+
+func (p *CounterPlugin) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
+ return &GRPCClient{
+ client: proto.NewCounterClient(c),
+ broker: broker,
+ }, nil
+}
+
+var _ plugin.GRPCPlugin = &CounterPlugin{}
--- /dev/null
+*.pyc
+kv
+kv-*
+kv_*
+!kv_*.py
--- /dev/null
+# KV Example
+
+This example builds a simple key/value store CLI where the mechanism
+for storing and retrieving keys is pluggable. To build this example:
+
+```sh
+# This builds the main CLI
+$ go build -o kv
+
+# This builds the plugin written in Go
+$ go build -o kv-go-grpc ./plugin-go-grpc
+
+# This tells the KV binary to use the "kv-go-grpc" binary
+$ export KV_PLUGIN="./kv-go-grpc"
+
+# Read and write
+$ ./kv put hello world
+
+$ ./kv get hello
+world
+```
+
+### Plugin: plugin-go-grpc
+
+This plugin uses gRPC to serve a plugin that is written in Go:
+
+```
+# This builds the plugin written in Go
+$ go build -o kv-go-grpc ./plugin-go-grpc
+
+# This tells the KV binary to use the "kv-go-grpc" binary
+$ export KV_PLUGIN="./kv-go-grpc"
+```
+
+### Plugin: plugin-go-netrpc
+
+This plugin uses the builtin Go net/rpc mechanism to serve the plugin:
+
+```
+# This builds the plugin written in Go
+$ go build -o kv-go-netrpc ./plugin-go-netrpc
+
+# This tells the KV binary to use the "kv-go-netrpc" binary
+$ export KV_PLUGIN="./kv-go-netrpc"
+```
+
+### Plugin: plugin-python
+
+This plugin is written in Python:
+
+```
+$ export KV_PLUGIN="python plugin-python/plugin.py"
+```
+
+## Updating the Protocol
+
+If you update the protocol buffers file, you can regenerate the file
+using the following command from this directory. You do not need to run
+this if you're just trying the example.
+
+For Go:
+
+```sh
+$ protoc -I proto/ proto/kv.proto --go_out=plugins=grpc:proto/
+```
+
+For Python:
+
+```sh
+$ python -m grpc_tools.protoc -I ./proto/ --python_out=./plugin-python/ --grpc_python_out=./plugin-python/ ./proto/kv.proto
+```
--- /dev/null
+package main
+
+import (
+ "fmt"
+ "io/ioutil"
+ "log"
+ "os"
+ "os/exec"
+
+ "github.com/hashicorp/go-plugin"
+ "github.com/hashicorp/go-plugin/examples/grpc/shared"
+)
+
+func main() {
+ // We don't want to see the plugin logs.
+ log.SetOutput(ioutil.Discard)
+
+ // We're a host. Start by launching the plugin process.
+ client := plugin.NewClient(&plugin.ClientConfig{
+ HandshakeConfig: shared.Handshake,
+ Plugins: shared.PluginMap,
+ Cmd: exec.Command("sh", "-c", os.Getenv("KV_PLUGIN")),
+ AllowedProtocols: []plugin.Protocol{
+ plugin.ProtocolNetRPC, plugin.ProtocolGRPC},
+ })
+ defer client.Kill()
+
+ // Connect via RPC
+ rpcClient, err := client.Client()
+ if err != nil {
+ fmt.Println("Error:", err.Error())
+ os.Exit(1)
+ }
+
+ // Request the plugin
+ raw, err := rpcClient.Dispense("kv")
+ if err != nil {
+ fmt.Println("Error:", err.Error())
+ os.Exit(1)
+ }
+
+ // We should have a KV store now! This feels like a normal interface
+ // implementation but is in fact over an RPC connection.
+ kv := raw.(shared.KV)
+ os.Args = os.Args[1:]
+ switch os.Args[0] {
+ case "get":
+ result, err := kv.Get(os.Args[1])
+ if err != nil {
+ fmt.Println("Error:", err.Error())
+ os.Exit(1)
+ }
+
+ fmt.Println(string(result))
+
+ case "put":
+ err := kv.Put(os.Args[1], []byte(os.Args[2]))
+ if err != nil {
+ fmt.Println("Error:", err.Error())
+ os.Exit(1)
+ }
+
+ default:
+ fmt.Println("Please only use 'get' or 'put'")
+ os.Exit(1)
+ }
+}
--- /dev/null
+package main
+
+import (
+ "fmt"
+ "io/ioutil"
+
+ "github.com/hashicorp/go-plugin"
+ "github.com/hashicorp/go-plugin/examples/grpc/shared"
+)
+
// KV is a file-backed key/value store: the value for a key lives in a
// local file named "kv_"+key.
type KV struct{}

// Put stores value under key, appending a marker identifying which
// plugin wrote it.
func (KV) Put(key string, value []byte) error {
	content := fmt.Sprintf("%s\n\nWritten from plugin-go-grpc", string(value))
	return ioutil.WriteFile("kv_"+key, []byte(content), 0644)
}

// Get returns the raw contents of key's backing file.
func (KV) Get(key string) ([]byte, error) {
	return ioutil.ReadFile("kv_" + key)
}
+
// main serves the KV implementation as a go-plugin over gRPC. This
// binary is launched and managed by the host process.
func main() {
	plugin.Serve(&plugin.ServeConfig{
		HandshakeConfig: shared.Handshake,
		Plugins: map[string]plugin.Plugin{
			// "kv" must match the name the host passes to Dispense.
			"kv": &shared.KVPlugin{Impl: &KV{}},
		},

		// A non-nil value here enables gRPC serving for this plugin...
		GRPCServer: plugin.DefaultGRPCServer,
	})
}
--- /dev/null
+package main
+
+import (
+ "fmt"
+ "io/ioutil"
+
+ "github.com/hashicorp/go-plugin"
+ "github.com/hashicorp/go-plugin/examples/grpc/shared"
+)
+
// KV implements the key/value store on top of local files; each key's
// value is kept in the file "kv_"+key.
type KV struct{}

// Put writes value to key's file, tagged with this plugin's name so the
// reader can tell which plugin produced it.
func (KV) Put(key string, value []byte) error {
	tagged := fmt.Sprintf("%s\n\nWritten from plugin-go-netrpc", string(value))
	return ioutil.WriteFile("kv_"+key, []byte(tagged), 0644)
}

// Get reads back the contents of key's file.
func (KV) Get(key string) ([]byte, error) {
	return ioutil.ReadFile("kv_" + key)
}
+
// main serves the KV implementation as a go-plugin. No GRPCServer is
// configured here, so the plugin is served over go-plugin's default
// net/rpc transport instead of gRPC.
func main() {
	plugin.Serve(&plugin.ServeConfig{
		HandshakeConfig: shared.Handshake,
		Plugins: map[string]plugin.Plugin{
			// "kv" must match the name the host passes to Dispense.
			"kv": &shared.KVPlugin{Impl: &KV{}},
		},
	})
}
--- /dev/null
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: kv.proto
+
+import sys
+_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf import descriptor_pb2
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor.FileDescriptor(
+ name='kv.proto',
+ package='proto',
+ syntax='proto3',
+ serialized_pb=_b('\n\x08kv.proto\x12\x05proto\"\x19\n\nGetRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\"\x1c\n\x0bGetResponse\x12\r\n\x05value\x18\x01 \x01(\x0c\"(\n\nPutRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty2Z\n\x02KV\x12,\n\x03Get\x12\x11.proto.GetRequest\x1a\x12.proto.GetResponse\x12&\n\x03Put\x12\x11.proto.PutRequest\x1a\x0c.proto.Emptyb\x06proto3')
+)
+_sym_db.RegisterFileDescriptor(DESCRIPTOR)
+
+
+
+
+_GETREQUEST = _descriptor.Descriptor(
+ name='GetRequest',
+ full_name='proto.GetRequest',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='key', full_name='proto.GetRequest.key', index=0,
+ number=1, type=9, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b("").decode('utf-8'),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=19,
+ serialized_end=44,
+)
+
+
+_GETRESPONSE = _descriptor.Descriptor(
+ name='GetResponse',
+ full_name='proto.GetResponse',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='value', full_name='proto.GetResponse.value', index=0,
+ number=1, type=12, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b(""),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=46,
+ serialized_end=74,
+)
+
+
+_PUTREQUEST = _descriptor.Descriptor(
+ name='PutRequest',
+ full_name='proto.PutRequest',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='key', full_name='proto.PutRequest.key', index=0,
+ number=1, type=9, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b("").decode('utf-8'),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None),
+ _descriptor.FieldDescriptor(
+ name='value', full_name='proto.PutRequest.value', index=1,
+ number=2, type=12, cpp_type=9, label=1,
+ has_default_value=False, default_value=_b(""),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=76,
+ serialized_end=116,
+)
+
+
+_EMPTY = _descriptor.Descriptor(
+ name='Empty',
+ full_name='proto.Empty',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=118,
+ serialized_end=125,
+)
+
+DESCRIPTOR.message_types_by_name['GetRequest'] = _GETREQUEST
+DESCRIPTOR.message_types_by_name['GetResponse'] = _GETRESPONSE
+DESCRIPTOR.message_types_by_name['PutRequest'] = _PUTREQUEST
+DESCRIPTOR.message_types_by_name['Empty'] = _EMPTY
+
+GetRequest = _reflection.GeneratedProtocolMessageType('GetRequest', (_message.Message,), dict(
+ DESCRIPTOR = _GETREQUEST,
+ __module__ = 'kv_pb2'
+ # @@protoc_insertion_point(class_scope:proto.GetRequest)
+ ))
+_sym_db.RegisterMessage(GetRequest)
+
+GetResponse = _reflection.GeneratedProtocolMessageType('GetResponse', (_message.Message,), dict(
+ DESCRIPTOR = _GETRESPONSE,
+ __module__ = 'kv_pb2'
+ # @@protoc_insertion_point(class_scope:proto.GetResponse)
+ ))
+_sym_db.RegisterMessage(GetResponse)
+
+PutRequest = _reflection.GeneratedProtocolMessageType('PutRequest', (_message.Message,), dict(
+ DESCRIPTOR = _PUTREQUEST,
+ __module__ = 'kv_pb2'
+ # @@protoc_insertion_point(class_scope:proto.PutRequest)
+ ))
+_sym_db.RegisterMessage(PutRequest)
+
+Empty = _reflection.GeneratedProtocolMessageType('Empty', (_message.Message,), dict(
+ DESCRIPTOR = _EMPTY,
+ __module__ = 'kv_pb2'
+ # @@protoc_insertion_point(class_scope:proto.Empty)
+ ))
+_sym_db.RegisterMessage(Empty)
+
+
+try:
+ # THESE ELEMENTS WILL BE DEPRECATED.
+ # Please use the generated *_pb2_grpc.py files instead.
+ import grpc
+ from grpc.beta import implementations as beta_implementations
+ from grpc.beta import interfaces as beta_interfaces
+ from grpc.framework.common import cardinality
+ from grpc.framework.interfaces.face import utilities as face_utilities
+
+
+ class KVStub(object):
+
+ def __init__(self, channel):
+ """Constructor.
+
+ Args:
+ channel: A grpc.Channel.
+ """
+ self.Get = channel.unary_unary(
+ '/proto.KV/Get',
+ request_serializer=GetRequest.SerializeToString,
+ response_deserializer=GetResponse.FromString,
+ )
+ self.Put = channel.unary_unary(
+ '/proto.KV/Put',
+ request_serializer=PutRequest.SerializeToString,
+ response_deserializer=Empty.FromString,
+ )
+
+
+ class KVServicer(object):
+
+ def Get(self, request, context):
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def Put(self, request, context):
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+
+ def add_KVServicer_to_server(servicer, server):
+ rpc_method_handlers = {
+ 'Get': grpc.unary_unary_rpc_method_handler(
+ servicer.Get,
+ request_deserializer=GetRequest.FromString,
+ response_serializer=GetResponse.SerializeToString,
+ ),
+ 'Put': grpc.unary_unary_rpc_method_handler(
+ servicer.Put,
+ request_deserializer=PutRequest.FromString,
+ response_serializer=Empty.SerializeToString,
+ ),
+ }
+ generic_handler = grpc.method_handlers_generic_handler(
+ 'proto.KV', rpc_method_handlers)
+ server.add_generic_rpc_handlers((generic_handler,))
+
+
+ class BetaKVServicer(object):
+ """The Beta API is deprecated for 0.15.0 and later.
+
+ It is recommended to use the GA API (classes and functions in this
+ file not marked beta) for all further purposes. This class was generated
+ only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0."""
+ def Get(self, request, context):
+ context.code(beta_interfaces.StatusCode.UNIMPLEMENTED)
+ def Put(self, request, context):
+ context.code(beta_interfaces.StatusCode.UNIMPLEMENTED)
+
+
+ class BetaKVStub(object):
+ """The Beta API is deprecated for 0.15.0 and later.
+
+ It is recommended to use the GA API (classes and functions in this
+ file not marked beta) for all further purposes. This class was generated
+ only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0."""
+ def Get(self, request, timeout, metadata=None, with_call=False, protocol_options=None):
+ raise NotImplementedError()
+ Get.future = None
+ def Put(self, request, timeout, metadata=None, with_call=False, protocol_options=None):
+ raise NotImplementedError()
+ Put.future = None
+
+
+ def beta_create_KV_server(servicer, pool=None, pool_size=None, default_timeout=None, maximum_timeout=None):
+ """The Beta API is deprecated for 0.15.0 and later.
+
+ It is recommended to use the GA API (classes and functions in this
+ file not marked beta) for all further purposes. This function was
+ generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0"""
+ request_deserializers = {
+ ('proto.KV', 'Get'): GetRequest.FromString,
+ ('proto.KV', 'Put'): PutRequest.FromString,
+ }
+ response_serializers = {
+ ('proto.KV', 'Get'): GetResponse.SerializeToString,
+ ('proto.KV', 'Put'): Empty.SerializeToString,
+ }
+ method_implementations = {
+ ('proto.KV', 'Get'): face_utilities.unary_unary_inline(servicer.Get),
+ ('proto.KV', 'Put'): face_utilities.unary_unary_inline(servicer.Put),
+ }
+ server_options = beta_implementations.server_options(request_deserializers=request_deserializers, response_serializers=response_serializers, thread_pool=pool, thread_pool_size=pool_size, default_timeout=default_timeout, maximum_timeout=maximum_timeout)
+ return beta_implementations.server(method_implementations, options=server_options)
+
+
+ def beta_create_KV_stub(channel, host=None, metadata_transformer=None, pool=None, pool_size=None):
+ """The Beta API is deprecated for 0.15.0 and later.
+
+ It is recommended to use the GA API (classes and functions in this
+ file not marked beta) for all further purposes. This function was
+ generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0"""
+ request_serializers = {
+ ('proto.KV', 'Get'): GetRequest.SerializeToString,
+ ('proto.KV', 'Put'): PutRequest.SerializeToString,
+ }
+ response_deserializers = {
+ ('proto.KV', 'Get'): GetResponse.FromString,
+ ('proto.KV', 'Put'): Empty.FromString,
+ }
+ cardinalities = {
+ 'Get': cardinality.Cardinality.UNARY_UNARY,
+ 'Put': cardinality.Cardinality.UNARY_UNARY,
+ }
+ stub_options = beta_implementations.stub_options(host=host, metadata_transformer=metadata_transformer, request_serializers=request_serializers, response_deserializers=response_deserializers, thread_pool=pool, thread_pool_size=pool_size)
+ return beta_implementations.dynamic_stub(channel, 'proto.KV', cardinalities, options=stub_options)
+except ImportError:
+ pass
+# @@protoc_insertion_point(module_scope)
--- /dev/null
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+import grpc
+
+import kv_pb2 as kv__pb2
+
+
+class KVStub(object):
+
+ def __init__(self, channel):
+ """Constructor.
+
+ Args:
+ channel: A grpc.Channel.
+ """
+ self.Get = channel.unary_unary(
+ '/proto.KV/Get',
+ request_serializer=kv__pb2.GetRequest.SerializeToString,
+ response_deserializer=kv__pb2.GetResponse.FromString,
+ )
+ self.Put = channel.unary_unary(
+ '/proto.KV/Put',
+ request_serializer=kv__pb2.PutRequest.SerializeToString,
+ response_deserializer=kv__pb2.Empty.FromString,
+ )
+
+
+class KVServicer(object):
+
+ def Get(self, request, context):
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def Put(self, request, context):
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+
+def add_KVServicer_to_server(servicer, server):
+ rpc_method_handlers = {
+ 'Get': grpc.unary_unary_rpc_method_handler(
+ servicer.Get,
+ request_deserializer=kv__pb2.GetRequest.FromString,
+ response_serializer=kv__pb2.GetResponse.SerializeToString,
+ ),
+ 'Put': grpc.unary_unary_rpc_method_handler(
+ servicer.Put,
+ request_deserializer=kv__pb2.PutRequest.FromString,
+ response_serializer=kv__pb2.Empty.SerializeToString,
+ ),
+ }
+ generic_handler = grpc.method_handlers_generic_handler(
+ 'proto.KV', rpc_method_handlers)
+ server.add_generic_rpc_handlers((generic_handler,))
--- /dev/null
+from concurrent import futures
+import sys
+import time
+
+import grpc
+
+import kv_pb2
+import kv_pb2_grpc
+
+from grpc_health.v1.health import HealthServicer
+from grpc_health.v1 import health_pb2, health_pb2_grpc
+
+class KVServicer(kv_pb2_grpc.KVServicer):
+ """Implementation of KV service."""
+
+ def Get(self, request, context):
+ filename = "kv_"+request.key
+ with open(filename, 'r+b') as f:
+ result = kv_pb2.GetResponse()
+ result.value = f.read()
+ return result
+
+ def Put(self, request, context):
+ filename = "kv_"+request.key
+ value = "{0}\n\nWritten from plugin-python".format(request.value)
+ with open(filename, 'w') as f:
+ f.write(value)
+
+ return kv_pb2.Empty()
+
+def serve():
+    """Start the KV gRPC plugin server and block until interrupted."""
+    # We need to build a health service to work with go-plugin
+    health = HealthServicer()
+    health.set("plugin", health_pb2.HealthCheckResponse.ServingStatus.Value('SERVING'))
+
+    # Start the server.
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    kv_pb2_grpc.add_KVServicer_to_server(KVServicer(), server)
+    health_pb2_grpc.add_HealthServicer_to_server(health, server)
+    server.add_insecure_port('127.0.0.1:1234')
+    server.start()
+
+    # Output information
+    # go-plugin handshake line: CORE-VERSION|APP-VERSION|NETWORK|ADDR|PROTOCOL.
+    # The host process reads this from stdout, so flush immediately.
+    print("1|1|tcp|127.0.0.1:1234|grpc")
+    sys.stdout.flush()
+
+    # grpc.server runs on background threads; sleep to keep the process
+    # alive until the host kills it (or Ctrl-C during development).
+    try:
+        while True:
+            time.sleep(60 * 60 * 24)
+    except KeyboardInterrupt:
+        server.stop(0)
+
+if __name__ == '__main__':
+    serve()
--- /dev/null
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: kv.proto
+
+/*
+Package proto is a generated protocol buffer package.
+
+It is generated from these files:
+ kv.proto
+
+It has these top-level messages:
+ GetRequest
+ GetResponse
+ PutRequest
+ Empty
+*/
+package proto
+
+import proto1 "github.com/golang/protobuf/proto"
+import fmt "fmt"
+import math "math"
+
+import (
+ context "golang.org/x/net/context"
+ grpc "google.golang.org/grpc"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto1.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto1.ProtoPackageIsVersion2 // please upgrade the proto package
+
+type GetRequest struct {
+ Key string `protobuf:"bytes,1,opt,name=key" json:"key,omitempty"`
+}
+
+func (m *GetRequest) Reset() { *m = GetRequest{} }
+func (m *GetRequest) String() string { return proto1.CompactTextString(m) }
+func (*GetRequest) ProtoMessage() {}
+func (*GetRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
+
+func (m *GetRequest) GetKey() string {
+ if m != nil {
+ return m.Key
+ }
+ return ""
+}
+
+type GetResponse struct {
+ Value []byte `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"`
+}
+
+func (m *GetResponse) Reset() { *m = GetResponse{} }
+func (m *GetResponse) String() string { return proto1.CompactTextString(m) }
+func (*GetResponse) ProtoMessage() {}
+func (*GetResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
+
+func (m *GetResponse) GetValue() []byte {
+ if m != nil {
+ return m.Value
+ }
+ return nil
+}
+
+type PutRequest struct {
+ Key string `protobuf:"bytes,1,opt,name=key" json:"key,omitempty"`
+ Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
+}
+
+func (m *PutRequest) Reset() { *m = PutRequest{} }
+func (m *PutRequest) String() string { return proto1.CompactTextString(m) }
+func (*PutRequest) ProtoMessage() {}
+func (*PutRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
+
+func (m *PutRequest) GetKey() string {
+ if m != nil {
+ return m.Key
+ }
+ return ""
+}
+
+func (m *PutRequest) GetValue() []byte {
+ if m != nil {
+ return m.Value
+ }
+ return nil
+}
+
+type Empty struct {
+}
+
+func (m *Empty) Reset() { *m = Empty{} }
+func (m *Empty) String() string { return proto1.CompactTextString(m) }
+func (*Empty) ProtoMessage() {}
+func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
+
+func init() {
+ proto1.RegisterType((*GetRequest)(nil), "proto.GetRequest")
+ proto1.RegisterType((*GetResponse)(nil), "proto.GetResponse")
+ proto1.RegisterType((*PutRequest)(nil), "proto.PutRequest")
+ proto1.RegisterType((*Empty)(nil), "proto.Empty")
+}
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ context.Context
+var _ grpc.ClientConn
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the grpc package it is being compiled against.
+const _ = grpc.SupportPackageIsVersion4
+
+// Client API for KV service
+
+type KVClient interface {
+ Get(ctx context.Context, in *GetRequest, opts ...grpc.CallOption) (*GetResponse, error)
+ Put(ctx context.Context, in *PutRequest, opts ...grpc.CallOption) (*Empty, error)
+}
+
+type kVClient struct {
+ cc *grpc.ClientConn
+}
+
+func NewKVClient(cc *grpc.ClientConn) KVClient {
+ return &kVClient{cc}
+}
+
+func (c *kVClient) Get(ctx context.Context, in *GetRequest, opts ...grpc.CallOption) (*GetResponse, error) {
+ out := new(GetResponse)
+ err := grpc.Invoke(ctx, "/proto.KV/Get", in, out, c.cc, opts...)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+func (c *kVClient) Put(ctx context.Context, in *PutRequest, opts ...grpc.CallOption) (*Empty, error) {
+ out := new(Empty)
+ err := grpc.Invoke(ctx, "/proto.KV/Put", in, out, c.cc, opts...)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+// Server API for KV service
+
+type KVServer interface {
+ Get(context.Context, *GetRequest) (*GetResponse, error)
+ Put(context.Context, *PutRequest) (*Empty, error)
+}
+
+func RegisterKVServer(s *grpc.Server, srv KVServer) {
+ s.RegisterService(&_KV_serviceDesc, srv)
+}
+
+func _KV_Get_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+ in := new(GetRequest)
+ if err := dec(in); err != nil {
+ return nil, err
+ }
+ if interceptor == nil {
+ return srv.(KVServer).Get(ctx, in)
+ }
+ info := &grpc.UnaryServerInfo{
+ Server: srv,
+ FullMethod: "/proto.KV/Get",
+ }
+ handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+ return srv.(KVServer).Get(ctx, req.(*GetRequest))
+ }
+ return interceptor(ctx, in, info, handler)
+}
+
+func _KV_Put_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+ in := new(PutRequest)
+ if err := dec(in); err != nil {
+ return nil, err
+ }
+ if interceptor == nil {
+ return srv.(KVServer).Put(ctx, in)
+ }
+ info := &grpc.UnaryServerInfo{
+ Server: srv,
+ FullMethod: "/proto.KV/Put",
+ }
+ handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+ return srv.(KVServer).Put(ctx, req.(*PutRequest))
+ }
+ return interceptor(ctx, in, info, handler)
+}
+
+var _KV_serviceDesc = grpc.ServiceDesc{
+ ServiceName: "proto.KV",
+ HandlerType: (*KVServer)(nil),
+ Methods: []grpc.MethodDesc{
+ {
+ MethodName: "Get",
+ Handler: _KV_Get_Handler,
+ },
+ {
+ MethodName: "Put",
+ Handler: _KV_Put_Handler,
+ },
+ },
+ Streams: []grpc.StreamDesc{},
+ Metadata: "kv.proto",
+}
+
+func init() { proto1.RegisterFile("kv.proto", fileDescriptor0) }
+
+var fileDescriptor0 = []byte{
+ // 162 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xc8, 0x2e, 0xd3, 0x2b,
+ 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x53, 0x4a, 0x72, 0x5c, 0x5c, 0xee, 0xa9, 0x25, 0x41,
+ 0xa9, 0x85, 0xa5, 0xa9, 0xc5, 0x25, 0x42, 0x02, 0x5c, 0xcc, 0xd9, 0xa9, 0x95, 0x12, 0x8c, 0x0a,
+ 0x8c, 0x1a, 0x9c, 0x41, 0x20, 0xa6, 0x92, 0x32, 0x17, 0x37, 0x58, 0xbe, 0xb8, 0x20, 0x3f, 0xaf,
+ 0x38, 0x55, 0x48, 0x84, 0x8b, 0xb5, 0x2c, 0x31, 0xa7, 0x34, 0x15, 0xac, 0x84, 0x27, 0x08, 0xc2,
+ 0x51, 0x32, 0xe1, 0xe2, 0x0a, 0x28, 0xc5, 0x6d, 0x08, 0x42, 0x17, 0x13, 0xb2, 0x2e, 0x76, 0x2e,
+ 0x56, 0xd7, 0xdc, 0x82, 0x92, 0x4a, 0xa3, 0x28, 0x2e, 0x26, 0xef, 0x30, 0x21, 0x1d, 0x2e, 0x66,
+ 0xf7, 0xd4, 0x12, 0x21, 0x41, 0x88, 0xfb, 0xf4, 0x10, 0xae, 0x92, 0x12, 0x42, 0x16, 0x82, 0x3a,
+ 0x44, 0x8d, 0x8b, 0x39, 0xa0, 0x14, 0xa1, 0x1a, 0x61, 0xbd, 0x14, 0x0f, 0x54, 0x08, 0x6c, 0x76,
+ 0x12, 0x1b, 0x98, 0x63, 0x0c, 0x08, 0x00, 0x00, 0xff, 0xff, 0x06, 0x32, 0x05, 0x89, 0xf9, 0x00,
+ 0x00, 0x00,
+}
--- /dev/null
+syntax = "proto3";
+package proto;
+
+// GetRequest asks for the value stored under key.
+message GetRequest {
+    string key = 1;
+}
+
+// GetResponse carries the raw bytes stored under the requested key.
+message GetResponse {
+    bytes value = 1;
+}
+
+// PutRequest stores value under key.
+message PutRequest {
+    string key = 1;
+    bytes value = 2;
+}
+
+// Empty is the response to Put, which returns no data.
+message Empty {}
+
+// KV is a simple key/value store service.
+service KV {
+    rpc Get(GetRequest) returns (GetResponse);
+    rpc Put(PutRequest) returns (Empty);
+}
--- /dev/null
+package shared
+
+import (
+ "github.com/hashicorp/go-plugin/examples/grpc/proto"
+ "golang.org/x/net/context"
+)
+
+// GRPCClient is an implementation of KV that talks over RPC.
+type GRPCClient struct{ client proto.KVClient }
+
+func (m *GRPCClient) Put(key string, value []byte) error {
+ _, err := m.client.Put(context.Background(), &proto.PutRequest{
+ Key: key,
+ Value: value,
+ })
+ return err
+}
+
+func (m *GRPCClient) Get(key string) ([]byte, error) {
+ resp, err := m.client.Get(context.Background(), &proto.GetRequest{
+ Key: key,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return resp.Value, nil
+}
+
+// Here is the gRPC server that GRPCClient talks to.
+type GRPCServer struct {
+ // This is the real implementation
+ Impl KV
+}
+
+func (m *GRPCServer) Put(
+ ctx context.Context,
+ req *proto.PutRequest) (*proto.Empty, error) {
+ return &proto.Empty{}, m.Impl.Put(req.Key, req.Value)
+}
+
+func (m *GRPCServer) Get(
+ ctx context.Context,
+ req *proto.GetRequest) (*proto.GetResponse, error) {
+ v, err := m.Impl.Get(req.Key)
+ return &proto.GetResponse{Value: v}, err
+}
--- /dev/null
+// Package shared contains shared data between the host and plugins.
+package shared
+
+import (
+ "context"
+ "net/rpc"
+
+ "google.golang.org/grpc"
+
+ "github.com/hashicorp/go-plugin"
+ "github.com/hashicorp/go-plugin/examples/grpc/proto"
+)
+
+// Handshake is a common handshake that is shared by plugin and host.
+// The magic cookie values must match on both sides for the plugin to
+// be dispensed.
+var Handshake = plugin.HandshakeConfig{
+	ProtocolVersion:  1,
+	MagicCookieKey:   "BASIC_PLUGIN",
+	MagicCookieValue: "hello",
+}
+
+// PluginMap is the map of plugins we can dispense, keyed by plugin name.
+var PluginMap = map[string]plugin.Plugin{
+	"kv": &KVPlugin{},
+}
+
+// KV is the interface that we're exposing as a plugin.
+type KV interface {
+	// Put stores value under key.
+	Put(key string, value []byte) error
+	// Get returns the value stored under key.
+	Get(key string) ([]byte, error)
+}
+
+// KVPlugin is the implementation of plugin.Plugin so we can serve/consume
+// this plugin. We also implement GRPCPlugin so that this plugin can be
+// served over gRPC.
+type KVPlugin struct {
+	// Concrete implementation, written in Go. This is only used for plugins
+	// that are written in Go.
+	Impl KV
+}
+
+// Server returns the net/rpc server wrapping the concrete implementation.
+func (p *KVPlugin) Server(*plugin.MuxBroker) (interface{}, error) {
+	return &RPCServer{Impl: p.Impl}, nil
+}
+
+// Client returns the net/rpc client implementation of KV.
+func (*KVPlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
+	return &RPCClient{client: c}, nil
+}
+
+// GRPCServer registers the KV gRPC server implementation with s.
+func (p *KVPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error {
+	proto.RegisterKVServer(s, &GRPCServer{Impl: p.Impl})
+	return nil
+}
+
+// GRPCClient returns the gRPC client implementation of KV.
+func (p *KVPlugin) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
+	return &GRPCClient{client: proto.NewKVClient(c)}, nil
+}
--- /dev/null
+package shared
+
+import (
+	"fmt"
+	"net/rpc"
+)
+
+// RPCClient is an implementation of KV that talks over net/rpc.
+type RPCClient struct{ client *rpc.Client }
+
+// Put stores value under key on the remote end.
+func (m *RPCClient) Put(key string, value []byte) error {
+	// No meaningful response is expected, so a throwaway interface{}
+	// receives the reply.
+	var resp interface{}
+
+	// Arguments travel as a simple map; a dedicated args struct would
+	// work as well.
+	args := map[string]interface{}{
+		"key":   key,
+		"value": value,
+	}
+	return m.client.Call("Plugin.Put", args, &resp)
+}
+
+// Get returns the value stored under key on the remote end.
+func (m *RPCClient) Get(key string) ([]byte, error) {
+	var value []byte
+	err := m.client.Call("Plugin.Get", key, &value)
+	return value, err
+}
+
+// RPCServer is the RPC server that RPCClient talks to, conforming to
+// the requirements of net/rpc.
+type RPCServer struct {
+	// Impl is the real, concrete implementation of KV.
+	Impl KV
+}
+
+// Put stores args["value"] under args["key"]. Malformed arguments are
+// reported as errors rather than panicking the RPC server, since the
+// args map arrives from the remote peer.
+func (m *RPCServer) Put(args map[string]interface{}, resp *interface{}) error {
+	key, ok := args["key"].(string)
+	if !ok {
+		return fmt.Errorf("put: expected string key, got %T", args["key"])
+	}
+	value, ok := args["value"].([]byte)
+	if !ok {
+		return fmt.Errorf("put: expected []byte value, got %T", args["value"])
+	}
+	return m.Impl.Put(key, value)
+}
+
+// Get writes the value stored under key into resp.
+func (m *RPCServer) Get(key string, resp *[]byte) error {
+	v, err := m.Impl.Get(key)
+	*resp = v
+	return err
+}
--- /dev/null
+package plugin
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "log"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/oklog/run"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+)
+
+// streamer interface is used in the broker to send/receive connection
+// information over the underlying gRPC stream. Both the server-side and
+// client-side broker endpoints implement it.
+type streamer interface {
+	Send(*ConnInfo) error
+	Recv() (*ConnInfo, error)
+	Close()
+}
+
+// sendErr pairs a ConnInfo to send with a channel on which the result
+// (error) of the send is reported back to the caller.
+type sendErr struct {
+	i  *ConnInfo
+	ch chan error
+}
+
+// gRPCBrokerServer is used by the plugin to start a stream and to send
+// connection information to/from the plugin. Implements GRPCBrokerServer and
+// streamer interfaces.
+type gRPCBrokerServer struct {
+	// send is used to send connection info to the gRPC stream.
+	send chan *sendErr
+
+	// recv is used to receive connection info from the gRPC stream.
+	recv chan *ConnInfo
+
+	// quit closes down the stream.
+	quit chan struct{}
+
+	// o is used to ensure we close the quit channel only once.
+	o sync.Once
+}
+
+// newGRPCBrokerServer returns a gRPCBrokerServer with all channels
+// initialized (unbuffered), ready for StartStream to be called.
+func newGRPCBrokerServer() *gRPCBrokerServer {
+	return &gRPCBrokerServer{
+		send: make(chan *sendErr),
+		recv: make(chan *ConnInfo),
+		quit: make(chan struct{}),
+	}
+}
+
+// StartStream implements the GRPCBrokerServer interface and will block until
+// the quit channel is closed or the context reports Done. The stream will pass
+// connection information to/from the client.
+func (s *gRPCBrokerServer) StartStream(stream GRPCBroker_StartStreamServer) error {
+	doneCh := stream.Context().Done()
+	defer s.Close()
+
+	// Process send stream: forward ConnInfo queued via Send onto the
+	// gRPC stream and report each result back on the per-send channel.
+	go func() {
+		for {
+			select {
+			case <-doneCh:
+				return
+			case <-s.quit:
+				return
+			case se := <-s.send:
+				err := stream.Send(se.i)
+				se.ch <- err
+			}
+		}
+	}()
+
+	// Process receive stream: pull ConnInfo off the gRPC stream and
+	// hand it to Recv callers until the stream errors or we shut down.
+	for {
+		i, err := stream.Recv()
+		if err != nil {
+			return err
+		}
+		select {
+		case <-doneCh:
+			return nil
+		case <-s.quit:
+			return nil
+		case s.recv <- i:
+		}
+	}
+}
+
+// Send is used by the GRPCBroker to pass connection information into the stream
+// to the client. It blocks until the StartStream goroutine performs the
+// actual stream send and reports the result on the reply channel, or
+// until the broker is closed.
+func (s *gRPCBrokerServer) Send(i *ConnInfo) error {
+	ch := make(chan error)
+	defer close(ch)
+
+	select {
+	case <-s.quit:
+		return errors.New("broker closed")
+	case s.send <- &sendErr{
+		i:  i,
+		ch: ch,
+	}:
+	}
+
+	// StartStream delivers the result of stream.Send here.
+	return <-ch
+}
+
+// Recv is used by the GRPCBroker to pass connection information that has been
+// sent from the client from the stream to the broker. It blocks until info
+// arrives or the broker is closed.
+func (s *gRPCBrokerServer) Recv() (*ConnInfo, error) {
+	select {
+	case <-s.quit:
+		return nil, errors.New("broker closed")
+	case i := <-s.recv:
+		return i, nil
+	}
+}
+
+// Close closes the quit channel, shutting down the stream. It is safe to
+// call multiple times; sync.Once guards the close.
+func (s *gRPCBrokerServer) Close() {
+	s.o.Do(func() {
+		close(s.quit)
+	})
+}
+
+// gRPCBrokerClientImpl is used by the client to start a stream and to send
+// connection information to/from the client. Implements GRPCBrokerClient and
+// streamer interfaces.
+type gRPCBrokerClientImpl struct {
+	// client is the underlying GRPC client used to make calls to the server.
+	client GRPCBrokerClient
+
+	// send is used to send connection info to the gRPC stream.
+	send chan *sendErr
+
+	// recv is used to receive connection info from the gRPC stream.
+	recv chan *ConnInfo
+
+	// quit closes down the stream.
+	quit chan struct{}
+
+	// o is used to ensure we close the quit channel only once.
+	o sync.Once
+}
+
+// newGRPCBrokerClient wraps conn in a broker client endpoint with all
+// channels initialized (unbuffered), ready for StartStream to be called.
+func newGRPCBrokerClient(conn *grpc.ClientConn) *gRPCBrokerClientImpl {
+	return &gRPCBrokerClientImpl{
+		client: NewGRPCBrokerClient(conn),
+		send:   make(chan *sendErr),
+		recv:   make(chan *ConnInfo),
+		quit:   make(chan struct{}),
+	}
+}
+
+// StartStream implements the GRPCBrokerClient interface and will block until
+// the quit channel is closed or the context reports Done. The stream will pass
+// connection information to/from the plugin.
+func (s *gRPCBrokerClientImpl) StartStream() error {
+	ctx, cancelFunc := context.WithCancel(context.Background())
+	defer cancelFunc()
+	defer s.Close()
+
+	stream, err := s.client.StartStream(ctx)
+	if err != nil {
+		return err
+	}
+	doneCh := stream.Context().Done()
+
+	// Process send stream: forward ConnInfo queued via Send onto the
+	// gRPC stream and report each result back on the per-send channel.
+	go func() {
+		for {
+			select {
+			case <-doneCh:
+				return
+			case <-s.quit:
+				return
+			case se := <-s.send:
+				err := stream.Send(se.i)
+				se.ch <- err
+			}
+		}
+	}()
+
+	// Process receive stream: pull ConnInfo off the gRPC stream and
+	// hand it to Recv callers until the stream errors or we shut down.
+	for {
+		i, err := stream.Recv()
+		if err != nil {
+			return err
+		}
+		select {
+		case <-doneCh:
+			return nil
+		case <-s.quit:
+			return nil
+		case s.recv <- i:
+		}
+	}
+}
+
+// Send is used by the GRPCBroker to pass connection information into the stream
+// to the plugin. It blocks until the StartStream goroutine performs the
+// actual stream send and reports the result on the reply channel, or
+// until the broker is closed.
+func (s *gRPCBrokerClientImpl) Send(i *ConnInfo) error {
+	ch := make(chan error)
+	defer close(ch)
+
+	select {
+	case <-s.quit:
+		return errors.New("broker closed")
+	case s.send <- &sendErr{
+		i:  i,
+		ch: ch,
+	}:
+	}
+
+	// StartStream delivers the result of stream.Send here.
+	return <-ch
+}
+
+// Recv is used by the GRPCBroker to pass connection information that has been
+// sent from the plugin to the broker. It blocks until info arrives or the
+// broker is closed.
+func (s *gRPCBrokerClientImpl) Recv() (*ConnInfo, error) {
+	select {
+	case <-s.quit:
+		return nil, errors.New("broker closed")
+	case i := <-s.recv:
+		return i, nil
+	}
+}
+
+// Close closes the quit channel, shutting down the stream. It is safe to
+// call multiple times; sync.Once guards the close.
+func (s *gRPCBrokerClientImpl) Close() {
+	s.o.Do(func() {
+		close(s.quit)
+	})
+}
+
+// GRPCBroker is responsible for brokering connections by unique ID.
+//
+// It is used by plugins to create multiple gRPC connections and data
+// streams between the plugin process and the host process.
+//
+// This allows a plugin to request a channel with a specific ID to connect to
+// or accept a connection from, and the broker handles the details of
+// holding these channels open while they're being negotiated.
+//
+// The Plugin interface has access to these for both Server and Client.
+// The broker can be used by either (optionally) to reserve and connect to
+// new streams. This is useful for complex args and return values,
+// or anything else you might need a data stream for.
+type GRPCBroker struct {
+	// nextId is advanced atomically by NextId.
+	nextId   uint32
+	// streamer is the underlying send/recv endpoint (server or client side).
+	streamer streamer
+	// streams holds pending connection info, keyed by service ID.
+	streams  map[uint32]*gRPCBrokerPending
+	// tls, when non-nil, is used for brokered connections.
+	tls      *tls.Config
+	// doneCh is closed by Close to stop AcceptAndServe servers.
+	doneCh   chan struct{}
+	// o ensures doneCh is closed only once.
+	o        sync.Once
+
+	// The embedded mutex guards the streams map.
+	sync.Mutex
+}
+
+// gRPCBrokerPending is a pending connection waiting to be claimed by Dial.
+type gRPCBrokerPending struct {
+	// ch carries the ConnInfo for the pending connection (buffered, size 1).
+	ch     chan *ConnInfo
+	// doneCh is closed once the connection info has been claimed.
+	doneCh chan struct{}
+}
+
+// newGRPCBroker builds a broker over the given streamer; tls may be nil.
+func newGRPCBroker(s streamer, tls *tls.Config) *GRPCBroker {
+	return &GRPCBroker{
+		streamer: s,
+		streams:  make(map[uint32]*gRPCBrokerPending),
+		tls:      tls,
+		doneCh:   make(chan struct{}),
+	}
+}
+
+// Accept accepts a connection by ID.
+//
+// This should not be called multiple times with the same ID at one time.
+func (b *GRPCBroker) Accept(id uint32) (net.Listener, error) {
+	listener, err := serverListener()
+	if err != nil {
+		return nil, err
+	}
+
+	// Advertise the listener's address to the other side of the stream.
+	err = b.streamer.Send(&ConnInfo{
+		ServiceId: id,
+		Network:   listener.Addr().Network(),
+		Address:   listener.Addr().String(),
+	})
+	if err != nil {
+		// Don't leak the listener if we couldn't advertise it.
+		listener.Close()
+		return nil, err
+	}
+
+	return listener, nil
+}
+
+// AcceptAndServe is used to accept a specific stream ID and immediately
+// serve a gRPC server on that stream ID. This is used to easily serve
+// complex arguments. Each AcceptAndServe call opens a new listener socket and
+// sends the connection info down the stream to the dialer. Since a new
+// connection is opened every call, these calls should be used sparingly.
+// Multiple gRPC server implementations can be registered to a single
+// AcceptAndServe call.
+//
+// The function blocks until the served gRPC server stops or the broker
+// is closed; errors are logged rather than returned.
+func (b *GRPCBroker) AcceptAndServe(id uint32, s func([]grpc.ServerOption) *grpc.Server) {
+	listener, err := b.Accept(id)
+	if err != nil {
+		log.Printf("[ERR] plugin: plugin acceptAndServe error: %s", err)
+		return
+	}
+	defer listener.Close()
+
+	// Serve with TLS credentials when the broker was configured with them.
+	var opts []grpc.ServerOption
+	if b.tls != nil {
+		opts = []grpc.ServerOption{grpc.Creds(credentials.NewTLS(b.tls))}
+	}
+
+	// The caller constructs the server so it can register its services.
+	server := s(opts)
+
+	// Here we use a run group to close this goroutine if the server is shutdown
+	// or the broker is shutdown.
+	var g run.Group
+	{
+		// Serve on the listener, if shutting down call GracefulStop.
+		g.Add(func() error {
+			return server.Serve(listener)
+		}, func(err error) {
+			server.GracefulStop()
+		})
+	}
+	{
+		// block on the closeCh or the doneCh. If we are shutting down close the
+		// closeCh.
+		closeCh := make(chan struct{})
+		g.Add(func() error {
+			select {
+			case <-b.doneCh:
+			case <-closeCh:
+			}
+			return nil
+		}, func(err error) {
+			close(closeCh)
+		})
+	}
+
+	// Block until we are done
+	g.Run()
+}
+
+// Close shuts down the underlying stream and signals all servers started
+// via AcceptAndServe to stop. It always returns nil.
+func (b *GRPCBroker) Close() error {
+	b.streamer.Close()
+	b.o.Do(func() { close(b.doneCh) })
+	return nil
+}
+
+// Dial opens a connection by ID.
+//
+// It blocks until connection information for the ID arrives from the
+// remote side (delivered by Run), or fails after a five second timeout.
+func (b *GRPCBroker) Dial(id uint32) (conn *grpc.ClientConn, err error) {
+	var c *ConnInfo
+
+	// Open the stream
+	p := b.getStream(id)
+	select {
+	case c = <-p.ch:
+		close(p.doneCh)
+	case <-time.After(5 * time.Second):
+		return nil, errors.New("timeout waiting for connection info")
+	}
+
+	var addr net.Addr
+	switch c.Network {
+	case "tcp":
+		addr, err = net.ResolveTCPAddr("tcp", c.Address)
+	case "unix":
+		addr, err = net.ResolveUnixAddr("unix", c.Address)
+	default:
+		// Error strings are lowercase, per Go convention.
+		err = fmt.Errorf("unknown address type: %s", c.Address)
+	}
+	if err != nil {
+		return nil, err
+	}
+
+	return dialGRPCConn(b.tls, netAddrDialer(addr))
+}
+
+// NextId returns a unique ID to use next.
+//
+// It is possible for very long-running plugin hosts to wrap this value,
+// though it would require a very large amount of calls. In practice
+// we've never seen it happen.
+func (b *GRPCBroker) NextId() uint32 {
+	return atomic.AddUint32(&b.nextId, 1)
+}
+
+// Run starts the brokering and should be executed in a goroutine, since it
+// blocks forever, or until the session closes.
+//
+// Uses of GRPCBroker never need to call this. It is called internally by
+// the plugin host/client.
+func (m *GRPCBroker) Run() {
+	for {
+		stream, err := m.streamer.Recv()
+		if err != nil {
+			// Once we receive an error, just exit
+			break
+		}
+
+		// Initialize the waiter
+		p := m.getStream(stream.ServiceId)
+		select {
+		case p.ch <- stream:
+		default:
+			// Non-blocking send: the buffered channel (size 1) already
+			// holds unclaimed ConnInfo for this ID, so drop the new one
+			// rather than blocking the whole receive loop.
+		}
+
+		// Reap the pending entry if no Dial claims it in time.
+		go m.timeoutWait(stream.ServiceId, p)
+	}
+}
+
+// getStream returns the pending-connection record for id, creating and
+// registering a fresh one if none exists yet. Safe for concurrent use.
+func (b *GRPCBroker) getStream(id uint32) *gRPCBrokerPending {
+	b.Lock()
+	defer b.Unlock()
+
+	if p, ok := b.streams[id]; ok {
+		return p
+	}
+
+	p := &gRPCBrokerPending{
+		ch:     make(chan *ConnInfo, 1),
+		doneCh: make(chan struct{}),
+	}
+	b.streams[id] = p
+	return p
+}
+
+// timeoutWait blocks until the pending stream p is claimed by Dial
+// (signaled via doneCh) or five seconds elapse, then removes id from the
+// pending-streams map.
+func (m *GRPCBroker) timeoutWait(id uint32, p *gRPCBrokerPending) {
+	// Wait for the stream to either be picked up and connected, or
+	// for a timeout.
+	select {
+	case <-p.doneCh:
+	case <-time.After(5 * time.Second):
+	}
+
+	m.Lock()
+	defer m.Unlock()
+
+	// Delete the stream so no one else can grab it
+	delete(m.streams, id)
+}
--- /dev/null
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: grpc_broker.proto
+
+/*
+Package plugin is a generated protocol buffer package.
+
+It is generated from these files:
+ grpc_broker.proto
+
+It has these top-level messages:
+ ConnInfo
+*/
+package plugin
+
+import proto "github.com/golang/protobuf/proto"
+import fmt "fmt"
+import math "math"
+
+import (
+ context "golang.org/x/net/context"
+ grpc "google.golang.org/grpc"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
+
+type ConnInfo struct {
+ ServiceId uint32 `protobuf:"varint,1,opt,name=service_id,json=serviceId" json:"service_id,omitempty"`
+ Network string `protobuf:"bytes,2,opt,name=network" json:"network,omitempty"`
+ Address string `protobuf:"bytes,3,opt,name=address" json:"address,omitempty"`
+}
+
+func (m *ConnInfo) Reset() { *m = ConnInfo{} }
+func (m *ConnInfo) String() string { return proto.CompactTextString(m) }
+func (*ConnInfo) ProtoMessage() {}
+func (*ConnInfo) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
+
+func (m *ConnInfo) GetServiceId() uint32 {
+ if m != nil {
+ return m.ServiceId
+ }
+ return 0
+}
+
+func (m *ConnInfo) GetNetwork() string {
+ if m != nil {
+ return m.Network
+ }
+ return ""
+}
+
+func (m *ConnInfo) GetAddress() string {
+ if m != nil {
+ return m.Address
+ }
+ return ""
+}
+
+func init() {
+ proto.RegisterType((*ConnInfo)(nil), "plugin.ConnInfo")
+}
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ context.Context
+var _ grpc.ClientConn
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the grpc package it is being compiled against.
+const _ = grpc.SupportPackageIsVersion4
+
+// Client API for GRPCBroker service
+
+type GRPCBrokerClient interface {
+ StartStream(ctx context.Context, opts ...grpc.CallOption) (GRPCBroker_StartStreamClient, error)
+}
+
+type gRPCBrokerClient struct {
+ cc *grpc.ClientConn
+}
+
+func NewGRPCBrokerClient(cc *grpc.ClientConn) GRPCBrokerClient {
+ return &gRPCBrokerClient{cc}
+}
+
+func (c *gRPCBrokerClient) StartStream(ctx context.Context, opts ...grpc.CallOption) (GRPCBroker_StartStreamClient, error) {
+ stream, err := grpc.NewClientStream(ctx, &_GRPCBroker_serviceDesc.Streams[0], c.cc, "/plugin.GRPCBroker/StartStream", opts...)
+ if err != nil {
+ return nil, err
+ }
+ x := &gRPCBrokerStartStreamClient{stream}
+ return x, nil
+}
+
+type GRPCBroker_StartStreamClient interface {
+ Send(*ConnInfo) error
+ Recv() (*ConnInfo, error)
+ grpc.ClientStream
+}
+
+type gRPCBrokerStartStreamClient struct {
+ grpc.ClientStream
+}
+
+func (x *gRPCBrokerStartStreamClient) Send(m *ConnInfo) error {
+ return x.ClientStream.SendMsg(m)
+}
+
+func (x *gRPCBrokerStartStreamClient) Recv() (*ConnInfo, error) {
+ m := new(ConnInfo)
+ if err := x.ClientStream.RecvMsg(m); err != nil {
+ return nil, err
+ }
+ return m, nil
+}
+
+// Server API for GRPCBroker service
+
+type GRPCBrokerServer interface {
+ StartStream(GRPCBroker_StartStreamServer) error
+}
+
+func RegisterGRPCBrokerServer(s *grpc.Server, srv GRPCBrokerServer) {
+ s.RegisterService(&_GRPCBroker_serviceDesc, srv)
+}
+
+func _GRPCBroker_StartStream_Handler(srv interface{}, stream grpc.ServerStream) error {
+ return srv.(GRPCBrokerServer).StartStream(&gRPCBrokerStartStreamServer{stream})
+}
+
+type GRPCBroker_StartStreamServer interface {
+ Send(*ConnInfo) error
+ Recv() (*ConnInfo, error)
+ grpc.ServerStream
+}
+
+type gRPCBrokerStartStreamServer struct {
+ grpc.ServerStream
+}
+
+func (x *gRPCBrokerStartStreamServer) Send(m *ConnInfo) error {
+ return x.ServerStream.SendMsg(m)
+}
+
+func (x *gRPCBrokerStartStreamServer) Recv() (*ConnInfo, error) {
+ m := new(ConnInfo)
+ if err := x.ServerStream.RecvMsg(m); err != nil {
+ return nil, err
+ }
+ return m, nil
+}
+
+var _GRPCBroker_serviceDesc = grpc.ServiceDesc{
+ ServiceName: "plugin.GRPCBroker",
+ HandlerType: (*GRPCBrokerServer)(nil),
+ Methods: []grpc.MethodDesc{},
+ Streams: []grpc.StreamDesc{
+ {
+ StreamName: "StartStream",
+ Handler: _GRPCBroker_StartStream_Handler,
+ ServerStreams: true,
+ ClientStreams: true,
+ },
+ },
+ Metadata: "grpc_broker.proto",
+}
+
+func init() { proto.RegisterFile("grpc_broker.proto", fileDescriptor0) }
+
+var fileDescriptor0 = []byte{
+ // 170 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x4c, 0x2f, 0x2a, 0x48,
+ 0x8e, 0x4f, 0x2a, 0xca, 0xcf, 0x4e, 0x2d, 0xd2, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2b,
+ 0xc8, 0x29, 0x4d, 0xcf, 0xcc, 0x53, 0x8a, 0xe5, 0xe2, 0x70, 0xce, 0xcf, 0xcb, 0xf3, 0xcc, 0x4b,
+ 0xcb, 0x17, 0x92, 0xe5, 0xe2, 0x2a, 0x4e, 0x2d, 0x2a, 0xcb, 0x4c, 0x4e, 0x8d, 0xcf, 0x4c, 0x91,
+ 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0d, 0xe2, 0x84, 0x8a, 0x78, 0xa6, 0x08, 0x49, 0x70, 0xb1, 0xe7,
+ 0xa5, 0x96, 0x94, 0xe7, 0x17, 0x65, 0x4b, 0x30, 0x29, 0x30, 0x6a, 0x70, 0x06, 0xc1, 0xb8, 0x20,
+ 0x99, 0xc4, 0x94, 0x94, 0xa2, 0xd4, 0xe2, 0x62, 0x09, 0x66, 0x88, 0x0c, 0x94, 0x6b, 0xe4, 0xcc,
+ 0xc5, 0xe5, 0x1e, 0x14, 0xe0, 0xec, 0x04, 0xb6, 0x5a, 0xc8, 0x94, 0x8b, 0x3b, 0xb8, 0x24, 0xb1,
+ 0xa8, 0x24, 0xb8, 0xa4, 0x28, 0x35, 0x31, 0x57, 0x48, 0x40, 0x0f, 0xe2, 0x08, 0x3d, 0x98, 0x0b,
+ 0xa4, 0x30, 0x44, 0x34, 0x18, 0x0d, 0x18, 0x93, 0xd8, 0xc0, 0x4e, 0x36, 0x06, 0x04, 0x00, 0x00,
+ 0xff, 0xff, 0x7b, 0x5d, 0xfb, 0xe1, 0xc7, 0x00, 0x00, 0x00,
+}
--- /dev/null
+syntax = "proto3";
+package plugin;
+
+// ConnInfo describes how to reach a brokered connection: the service ID
+// it was reserved under, plus the network type and address to dial.
+message ConnInfo {
+    uint32 service_id = 1;
+    string network = 2;
+    string address = 3;
+}
+
+// GRPCBroker exchanges connection information between host and plugin
+// over a single bidirectional stream.
+service GRPCBroker {
+    rpc StartStream(stream ConnInfo) returns (stream ConnInfo);
+}
+
+
--- /dev/null
+package plugin
+
+import (
+ "crypto/tls"
+ "fmt"
+ "net"
+ "time"
+
+ "golang.org/x/net/context"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/health/grpc_health_v1"
+)
+
+// dialGRPCConn establishes a gRPC client connection using the given
+// (possibly nil) TLS configuration and a custom net.Conn dialer.
+func dialGRPCConn(tls *tls.Config, dialer func(string, time.Duration) (net.Conn, error)) (*grpc.ClientConn, error) {
+	// Build dialing options.
+	opts := make([]grpc.DialOption, 0, 5)
+
+	// A custom dialer lets us connect over unix domain sockets.
+	opts = append(opts, grpc.WithDialer(dialer))
+
+	// go-plugin expects to block until the connection is established.
+	opts = append(opts, grpc.WithBlock())
+
+	// Fail right away rather than retrying on dial errors.
+	opts = append(opts, grpc.FailOnNonTempDialError(true))
+
+	// Without a TLS configuration we must explicitly tell grpc that
+	// we're connecting insecurely; otherwise wrap the TLS config in
+	// transport credentials.
+	if tls == nil {
+		opts = append(opts, grpc.WithInsecure())
+	} else {
+		opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tls)))
+	}
+
+	// The target address is unused because the custom dialer carries
+	// the real address in its own state.
+	return grpc.Dial("unused", opts...)
+}
+
+// newGRPCClient creates a new GRPCClient. The Client argument is expected
+// to be successfully started already with a lock held.
+func newGRPCClient(doneCtx context.Context, c *Client) (*GRPCClient, error) {
+	conn, err := dialGRPCConn(c.config.TLSConfig, c.dialer)
+	if err != nil {
+		return nil, err
+	}
+
+	// Start the broker. Run processes incoming connection info while
+	// StartStream keeps the bidirectional broker stream open.
+	brokerGRPCClient := newGRPCBrokerClient(conn)
+	broker := newGRPCBroker(brokerGRPCClient, c.config.TLSConfig)
+	go broker.Run()
+	go brokerGRPCClient.StartStream()
+
+	return &GRPCClient{
+		Conn:    conn,
+		Plugins: c.config.Plugins,
+		doneCtx: doneCtx,
+		broker:  broker,
+	}, nil
+}
+
+// GRPCClient connects to a GRPCServer over gRPC to dispense plugin types.
+type GRPCClient struct {
+	// Conn is the underlying client connection to the plugin server.
+	Conn *grpc.ClientConn
+	// Plugins maps plugin names to their Plugin implementations.
+	Plugins map[string]Plugin
+
+	// doneCtx is supplied by the caller of newGRPCClient and passed
+	// through to GRPCPlugin.GRPCClient on Dispense.
+	doneCtx context.Context
+	// broker handles brokering of additional gRPC connections by ID.
+	broker *GRPCBroker
+}
+
+// Close implements ClientProtocol by shutting down the broker and the
+// underlying gRPC connection.
+func (c *GRPCClient) Close() error {
+	c.broker.Close()
+	return c.Conn.Close()
+}
+
+// Dispense implements ClientProtocol by looking up the named plugin and
+// asking its GRPCPlugin implementation for a client value.
+func (c *GRPCClient) Dispense(name string) (interface{}, error) {
+	raw, ok := c.Plugins[name]
+	if !ok {
+		return nil, fmt.Errorf("unknown plugin type: %s", name)
+	}
+
+	p, ok := raw.(GRPCPlugin)
+	if !ok {
+		return nil, fmt.Errorf("plugin %q doesn't support gRPC", name)
+	}
+
+	return p.GRPCClient(c.doneCtx, c.broker, c.Conn)
+}
+
+// Ping implements ClientProtocol by issuing a health check against the
+// standard gRPC health service.
+func (c *GRPCClient) Ping() error {
+	req := &grpc_health_v1.HealthCheckRequest{Service: GRPCServiceName}
+	_, err := grpc_health_v1.NewHealthClient(c.Conn).Check(context.Background(), req)
+	return err
+}
--- /dev/null
+package plugin
+
+import (
+ "context"
+ "reflect"
+ "testing"
+
+ "github.com/hashicorp/go-plugin/test/grpc"
+ "google.golang.org/grpc"
+)
+
+func TestGRPCClient_App(t *testing.T) {
+ client, server := TestPluginGRPCConn(t, map[string]Plugin{
+ "test": new(testInterfacePlugin),
+ })
+ defer client.Close()
+ defer server.Stop()
+
+ raw, err := client.Dispense("test")
+ if err != nil {
+ t.Fatalf("err: %s", err)
+ }
+
+ impl, ok := raw.(testInterface)
+ if !ok {
+ t.Fatalf("bad: %#v", raw)
+ }
+
+ result := impl.Double(21)
+ if result != 42 {
+ t.Fatalf("bad: %#v", result)
+ }
+
+ err = impl.Bidirectional()
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestGRPCConn_BidirectionalPing(t *testing.T) {
+ conn, _ := TestGRPCConn(t, func(s *grpc.Server) {
+ grpctest.RegisterPingPongServer(s, &pingPongServer{})
+ })
+ defer conn.Close()
+ pingPongClient := grpctest.NewPingPongClient(conn)
+
+ pResp, err := pingPongClient.Ping(context.Background(), &grpctest.PingRequest{})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if pResp.Msg != "pong" {
+ t.Fatal("Bad PingPong")
+ }
+}
+
+func TestGRPCC_Stream(t *testing.T) {
+ client, server := TestPluginGRPCConn(t, map[string]Plugin{
+ "test": new(testInterfacePlugin),
+ })
+ defer client.Close()
+ defer server.Stop()
+
+ raw, err := client.Dispense("test")
+ if err != nil {
+ t.Fatalf("err: %s", err)
+ }
+
+ impl, ok := raw.(testStreamer)
+ if !ok {
+ t.Fatalf("bad: %#v", raw)
+ }
+
+ expected := []int32{21, 22, 23, 24, 25, 26}
+ result, err := impl.Stream(21, 27)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if !reflect.DeepEqual(result, expected) {
+ t.Fatalf("expected: %v\ngot: %v", expected, result)
+ }
+}
+
// TestGRPCClient_Ping checks that Ping succeeds while the remote gRPC
// server is up and fails once it has been stopped.
func TestGRPCClient_Ping(t *testing.T) {
	client, server := TestPluginGRPCConn(t, map[string]Plugin{
		"test": new(testInterfacePlugin),
	})
	defer client.Close()
	defer server.Stop()

	// Run a couple pings to confirm the health service answers repeatedly.
	if err := client.Ping(); err != nil {
		t.Fatalf("err: %s", err)
	}
	if err := client.Ping(); err != nil {
		t.Fatalf("err: %s", err)
	}

	// Close the remote end by stopping the inner *grpc.Server directly.
	server.server.Stop()

	// With the server gone, the health check must now fail.
	if err := client.Ping(); err == nil {
		t.Fatal("should error")
	}
}
--- /dev/null
+package plugin
+
+import (
+ "bytes"
+ "crypto/tls"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net"
+
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/health"
+ "google.golang.org/grpc/health/grpc_health_v1"
+)
+
// GRPCServiceName is the name of the service that the health check should
// return as passing. Clients ping this service to verify the plugin's
// gRPC server is alive.
const GRPCServiceName = "plugin"
+
// DefaultGRPCServer can be used with the "GRPCServer" field for Server
// as a default factory method to create a gRPC server with no extra options.
// The provided options (e.g. TLS credentials added by Init) are applied as-is.
func DefaultGRPCServer(opts []grpc.ServerOption) *grpc.Server {
	return grpc.NewServer(opts...)
}
+
// GRPCServer is a ServerType implementation that serves plugins over
// gRPC. This allows plugins to easily be written for other languages.
//
// The GRPCServer outputs a custom configuration as a base64-encoded
// JSON structure represented by the GRPCServerConfig config structure.
type GRPCServer struct {
	// Plugins are the list of plugins to serve.
	Plugins map[string]Plugin

	// Server is the actual server that will accept connections. This
	// will be used for plugin registration as well.
	Server func([]grpc.ServerOption) *grpc.Server

	// TLS should be the TLS configuration if available. If this is nil,
	// the connection will not have transport security.
	TLS *tls.Config

	// DoneCh is the channel that is closed when this server has exited.
	DoneCh chan struct{}

	// Stdout/StderrLis are the readers for stdout/stderr that will be copied
	// to the stdout/stderr connection that is output.
	Stdout io.Reader
	Stderr io.Reader

	config GRPCServerConfig // extra config surfaced to the host via Config()
	server *grpc.Server     // built by the Server factory during Init
	broker *GRPCBroker      // brokers extra streams for bidirectional plugins
}
+
+// ServerProtocol impl.
+func (s *GRPCServer) Init() error {
+ // Create our server
+ var opts []grpc.ServerOption
+ if s.TLS != nil {
+ opts = append(opts, grpc.Creds(credentials.NewTLS(s.TLS)))
+ }
+ s.server = s.Server(opts)
+
+ // Register the health service
+ healthCheck := health.NewServer()
+ healthCheck.SetServingStatus(
+ GRPCServiceName, grpc_health_v1.HealthCheckResponse_SERVING)
+ grpc_health_v1.RegisterHealthServer(s.server, healthCheck)
+
+ // Register the broker service
+ brokerServer := newGRPCBrokerServer()
+ RegisterGRPCBrokerServer(s.server, brokerServer)
+ s.broker = newGRPCBroker(brokerServer, s.TLS)
+ go s.broker.Run()
+
+ // Register all our plugins onto the gRPC server.
+ for k, raw := range s.Plugins {
+ p, ok := raw.(GRPCPlugin)
+ if !ok {
+ return fmt.Errorf("%q is not a GRPC-compatible plugin", k)
+ }
+
+ if err := p.GRPCServer(s.broker, s.server); err != nil {
+ return fmt.Errorf("error registring %q: %s", k, err)
+ }
+ }
+
+ return nil
+}
+
// Stop calls Stop on the underlying grpc.Server, immediately terminating
// all active RPCs and listeners.
func (s *GRPCServer) Stop() {
	s.server.Stop()
}

// GracefulStop calls GracefulStop on the underlying grpc.Server, which
// stops accepting new RPCs and blocks until pending ones complete.
func (s *GRPCServer) GracefulStop() {
	s.server.GracefulStop()
}
+
// Config returns the GRPCServerConfig encoded as JSON.
//
// NOTE(review): the original comment claimed the result was "JSON then
// base64", but the implementation below only JSON-encodes; callers
// receive plain JSON text.
func (s *GRPCServer) Config() string {
	// Create a buffer that will contain our final contents
	var buf bytes.Buffer

	// JSON-encode the config structure directly into the buffer.
	if err := json.NewEncoder(&buf).Encode(s.config); err != nil {
		// We panic since this shouldn't happen under any scenario. We
		// carefully control the structure being encoded here and it should
		// always be successful.
		panic(err)
	}

	return buf.String()
}
+
// Serve starts serving gRPC connections on the given listener in a
// background goroutine and blocks until DoneCh is closed. Any error
// returned by the inner grpc Serve call is intentionally discarded.
func (s *GRPCServer) Serve(lis net.Listener) {
	// Start serving in a goroutine
	go s.server.Serve(lis)

	// Wait until graceful completion
	<-s.DoneCh
}
+
// GRPCServerConfig is the extra configuration passed along for consumers
// to facilitate using GRPC plugins.
type GRPCServerConfig struct {
	// StdoutAddr/StderrAddr are the network addresses where the plugin's
	// stdout/stderr streams can be consumed by the host.
	StdoutAddr string `json:"stdout_addr"`
	StderrAddr string `json:"stderr_addr"`
}
--- /dev/null
+package plugin
+
+import (
+ "encoding/json"
+ "time"
+)
+
// logEntry is the JSON payload that gets sent to Stderr from the plugin to the host
type logEntry struct {
	Message string `json:"@message"`
	Level   string `json:"@level"`
	// NOTE(review): parseJSON reads the "@timestamp" key from incoming
	// payloads, but this tag serializes as "timestamp" — confirm whether
	// the asymmetry is intended.
	Timestamp time.Time     `json:"timestamp"`
	KVPairs   []*logEntryKV `json:"kv_pairs"`
}

// logEntryKV is a key value pair within the Output payload
type logEntryKV struct {
	Key   string      `json:"key"`
	Value interface{} `json:"value"`
}
+
+// flattenKVPairs is used to flatten KVPair slice into []interface{}
+// for hclog consumption.
+func flattenKVPairs(kvs []*logEntryKV) []interface{} {
+ var result []interface{}
+ for _, kv := range kvs {
+ result = append(result, kv.Key)
+ result = append(result, kv.Value)
+ }
+
+ return result
+}
+
+// parseJSON handles parsing JSON output
+func parseJSON(input string) (*logEntry, error) {
+ var raw map[string]interface{}
+ entry := &logEntry{}
+
+ err := json.Unmarshal([]byte(input), &raw)
+ if err != nil {
+ return nil, err
+ }
+
+ // Parse hclog-specific objects
+ if v, ok := raw["@message"]; ok {
+ entry.Message = v.(string)
+ delete(raw, "@message")
+ }
+
+ if v, ok := raw["@level"]; ok {
+ entry.Level = v.(string)
+ delete(raw, "@level")
+ }
+
+ if v, ok := raw["@timestamp"]; ok {
+ t, err := time.Parse("2006-01-02T15:04:05.000000Z07:00", v.(string))
+ if err != nil {
+ return nil, err
+ }
+ entry.Timestamp = t
+ delete(raw, "@timestamp")
+ }
+
+ // Parse dynamic KV args from the hclog payload.
+ for k, v := range raw {
+ entry.KVPairs = append(entry.KVPairs, &logEntryKV{
+ Key: k,
+ Value: v,
+ })
+ }
+
+ return entry, nil
+}
--- /dev/null
+package plugin
+
+import (
+ "encoding/binary"
+ "fmt"
+ "log"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/hashicorp/yamux"
+)
+
// MuxBroker is responsible for brokering multiplexed connections by unique ID.
//
// It is used by plugins to multiplex multiple RPC connections and data
// streams on top of a single connection between the plugin process and the
// host process.
//
// This allows a plugin to request a channel with a specific ID to connect to
// or accept a connection from, and the broker handles the details of
// holding these channels open while they're being negotiated.
//
// The Plugin interface has access to these for both Server and Client.
// The broker can be used by either (optionally) to reserve and connect to
// new multiplexed streams. This is useful for complex args and return values,
// or anything else you might need a data stream for.
type MuxBroker struct {
	// nextId is incremented atomically by NextId; do not touch directly.
	nextId  uint32
	session *yamux.Session
	// streams holds the pending (not-yet-accepted) streams keyed by ID,
	// guarded by the embedded mutex.
	streams map[uint32]*muxBrokerPending

	sync.Mutex
}

// muxBrokerPending tracks one stream awaiting an Accept for its ID.
type muxBrokerPending struct {
	// ch buffers at most one incoming stream for the ID.
	ch chan net.Conn
	// doneCh is closed by Accept once the stream has been claimed.
	doneCh chan struct{}
}
+
// newMuxBroker creates a MuxBroker on top of an established yamux session.
func newMuxBroker(s *yamux.Session) *MuxBroker {
	return &MuxBroker{
		session: s,
		streams: make(map[uint32]*muxBrokerPending),
	}
}
+
+// Accept accepts a connection by ID.
+//
+// This should not be called multiple times with the same ID at one time.
+func (m *MuxBroker) Accept(id uint32) (net.Conn, error) {
+ var c net.Conn
+ p := m.getStream(id)
+ select {
+ case c = <-p.ch:
+ close(p.doneCh)
+ case <-time.After(5 * time.Second):
+ m.Lock()
+ defer m.Unlock()
+ delete(m.streams, id)
+
+ return nil, fmt.Errorf("timeout waiting for accept")
+ }
+
+ // Ack our connection
+ if err := binary.Write(c, binary.LittleEndian, id); err != nil {
+ c.Close()
+ return nil, err
+ }
+
+ return c, nil
+}
+
// AcceptAndServe is used to accept a specific stream ID and immediately
// serve an RPC server on that stream ID. This is used to easily serve
// complex arguments.
//
// The served interface is always registered to the "Plugin" name.
// Accept errors (e.g. timeout) are logged and swallowed; the caller gets
// no feedback beyond the log line.
func (m *MuxBroker) AcceptAndServe(id uint32, v interface{}) {
	conn, err := m.Accept(id)
	if err != nil {
		log.Printf("[ERR] plugin: plugin acceptAndServe error: %s", err)
		return
	}

	serve(conn, "Plugin", v)
}
+
// Close closes the connection and all sub-connections by closing the
// underlying yamux session.
func (m *MuxBroker) Close() error {
	return m.session.Close()
}
+
+// Dial opens a connection by ID.
+func (m *MuxBroker) Dial(id uint32) (net.Conn, error) {
+ // Open the stream
+ stream, err := m.session.OpenStream()
+ if err != nil {
+ return nil, err
+ }
+
+ // Write the stream ID onto the wire.
+ if err := binary.Write(stream, binary.LittleEndian, id); err != nil {
+ stream.Close()
+ return nil, err
+ }
+
+ // Read the ack that we connected. Then we're off!
+ var ack uint32
+ if err := binary.Read(stream, binary.LittleEndian, &ack); err != nil {
+ stream.Close()
+ return nil, err
+ }
+ if ack != id {
+ stream.Close()
+ return nil, fmt.Errorf("bad ack: %d (expected %d)", ack, id)
+ }
+
+ return stream, nil
+}
+
// NextId returns a unique ID to use next.
//
// It is possible for very long-running plugin hosts to wrap this value,
// though it would require a very large amount of RPC calls. In practice
// we've never seen it happen.
func (m *MuxBroker) NextId() uint32 {
	return atomic.AddUint32(&m.nextId, 1)
}
+
// Run starts the brokering and should be executed in a goroutine, since it
// blocks forever, or until the session closes.
//
// Uses of MuxBroker never need to call this. It is called internally by
// the plugin host/client.
func (m *MuxBroker) Run() {
	for {
		stream, err := m.session.AcceptStream()
		if err != nil {
			// Once we receive an error, just exit
			break
		}

		// Read the stream ID from the stream
		var id uint32
		if err := binary.Read(stream, binary.LittleEndian, &id); err != nil {
			stream.Close()
			continue
		}

		// Initialize the waiter. The channel is buffered with capacity 1;
		// if a stream for this ID is already pending, the new one is
		// silently dropped here (non-blocking send).
		p := m.getStream(id)
		select {
		case p.ch <- stream:
		default:
		}

		// Wait for a timeout in the background so unclaimed streams are
		// eventually cleaned up by timeoutWait.
		go m.timeoutWait(id, p)
	}
}
+
+func (m *MuxBroker) getStream(id uint32) *muxBrokerPending {
+ m.Lock()
+ defer m.Unlock()
+
+ p, ok := m.streams[id]
+ if ok {
+ return p
+ }
+
+ m.streams[id] = &muxBrokerPending{
+ ch: make(chan net.Conn, 1),
+ doneCh: make(chan struct{}),
+ }
+ return m.streams[id]
+}
+
+func (m *MuxBroker) timeoutWait(id uint32, p *muxBrokerPending) {
+ // Wait for the stream to either be picked up and connected, or
+ // for a timeout.
+ timeout := false
+ select {
+ case <-p.doneCh:
+ case <-time.After(5 * time.Second):
+ timeout = true
+ }
+
+ m.Lock()
+ defer m.Unlock()
+
+ // Delete the stream so no one else can grab it
+ delete(m.streams, id)
+
+ // If we timed out, then check if we have a channel in the buffer,
+ // and if so, close it.
+ if timeout {
+ select {
+ case s := <-p.ch:
+ s.Close()
+ }
+ }
+}
--- /dev/null
+// The plugin package exposes functions and helpers for communicating to
+// plugins which are implemented as standalone binary applications.
+//
+// plugin.Client fully manages the lifecycle of executing the application,
+// connecting to it, and returning the RPC client for dispensing plugins.
+//
+// plugin.Serve fully manages listeners to expose an RPC server from a binary
+// that plugin.Client can connect to.
+package plugin
+
+import (
+ "context"
+ "errors"
+ "net/rpc"
+
+ "google.golang.org/grpc"
+)
+
// Plugin is the interface that is implemented to serve/connect to an
// interface implementation.
type Plugin interface {
	// Server should return the RPC server compatible struct to serve
	// the methods that the Client calls over net/rpc.
	Server(*MuxBroker) (interface{}, error)

	// Client returns an interface implementation for the plugin you're
	// serving that communicates to the server end of the plugin.
	Client(*MuxBroker, *rpc.Client) (interface{}, error)
}

// GRPCPlugin is the interface that is implemented to serve/connect to
// a plugin over gRPC.
type GRPCPlugin interface {
	// GRPCServer should register this plugin for serving with the
	// given GRPCServer. Unlike Plugin.Server, this is only called once
	// since gRPC plugins serve singletons.
	GRPCServer(*GRPCBroker, *grpc.Server) error

	// GRPCClient should return the interface implementation for the plugin
	// you're serving via gRPC. The provided context will be canceled by
	// go-plugin in the event of the plugin process exiting.
	GRPCClient(context.Context, *GRPCBroker, *grpc.ClientConn) (interface{}, error)
}
+
// NetRPCUnsupportedPlugin implements Plugin but returns errors for the
// Server and Client functions. This will effectively disable support for
// net/rpc based plugins.
//
// This struct can be embedded in your struct.
type NetRPCUnsupportedPlugin struct{}

// Server always returns an error: net/rpc is not supported.
func (p NetRPCUnsupportedPlugin) Server(*MuxBroker) (interface{}, error) {
	return nil, errors.New("net/rpc plugin protocol not supported")
}

// Client always returns an error: net/rpc is not supported.
func (p NetRPCUnsupportedPlugin) Client(*MuxBroker, *rpc.Client) (interface{}, error) {
	return nil, errors.New("net/rpc plugin protocol not supported")
}
--- /dev/null
+package plugin
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net/rpc"
+ "os"
+ "os/exec"
+ "testing"
+ "time"
+
+ hclog "github.com/hashicorp/go-hclog"
+ "github.com/hashicorp/go-plugin/test/grpc"
+ "golang.org/x/net/context"
+ "google.golang.org/grpc"
+)
+
// Test that NetRPCUnsupportedPlugin implements the correct interfaces.
var _ Plugin = new(NetRPCUnsupportedPlugin)

// testAPIVersion is the ProtocolVersion we use for testing.
// The magic cookie pair must match what the helper process serves with.
var testHandshake = HandshakeConfig{
	ProtocolVersion:  1,
	MagicCookieKey:   "TEST_MAGIC_COOKIE",
	MagicCookieValue: "test",
}
+
// testInterface is the test interface we use for plugins.
type testInterface interface {
	// Double returns twice its argument.
	Double(int) int
	// PrintKV logs a key/value pair on the plugin side.
	PrintKV(string, interface{})
	// Bidirectional exercises the broker round trip in both directions.
	Bidirectional() error
}

// testStreamer is used to test the grpc streaming interface
type testStreamer interface {
	// Stream returns the echoed values for the half-open range [start, stop).
	Stream(int32, int32) ([]int32, error)
}
+
// testInterfacePlugin is the implementation of Plugin to create
// RPC client/server implementations for testInterface.
type testInterfacePlugin struct {
	// Impl optionally overrides the default implementation; see impl().
	Impl testInterface
}

// Server returns the net/rpc server wrapper for the implementation.
func (p *testInterfacePlugin) Server(b *MuxBroker) (interface{}, error) {
	return &testInterfaceServer{Impl: p.impl()}, nil
}

// Client returns the net/rpc client wrapper.
func (p *testInterfacePlugin) Client(b *MuxBroker, c *rpc.Client) (interface{}, error) {
	return &testInterfaceClient{Client: c}, nil
}
+
// GRPCServer registers the gRPC test service backed by the implementation.
func (p *testInterfacePlugin) GRPCServer(b *GRPCBroker, s *grpc.Server) error {
	grpctest.RegisterTestServer(s, &testGRPCServer{broker: b, Impl: p.impl()})
	return nil
}

// GRPCClient returns the gRPC-backed testInterface client.
func (p *testInterfacePlugin) GRPCClient(doneCtx context.Context, b *GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
	return &testGRPCClient{broker: b, Client: grpctest.NewTestClient(c)}, nil
}
+
+func (p *testInterfacePlugin) impl() testInterface {
+ if p.Impl != nil {
+ return p.Impl
+ }
+
+ return &testInterfaceImpl{
+ logger: hclog.New(&hclog.LoggerOptions{
+ Level: hclog.Trace,
+ Output: os.Stderr,
+ JSONFormat: true,
+ }),
+ }
+}
+
// testInterfaceImpl implements testInterface concretely
type testInterfaceImpl struct {
	// logger receives the PrintKV output.
	logger hclog.Logger
}

// Double returns twice its argument.
func (i *testInterfaceImpl) Double(v int) int { return v * 2 }

// PrintKV logs the pair via the configured hclog logger.
func (i *testInterfaceImpl) PrintKV(key string, value interface{}) {
	i.logger.Info("PrintKV called", key, value)
}

// Bidirectional is a no-op on the implementation side; the transport
// wrappers do the actual broker round trip.
func (i *testInterfaceImpl) Bidirectional() error {
	return nil
}
+
// testInterfaceClient implements testInterface to communicate over RPC
type testInterfaceClient struct {
	Client *rpc.Client
}

// Double calls Plugin.Double over net/rpc; panics on transport errors
// (acceptable in test-only code).
func (impl *testInterfaceClient) Double(v int) int {
	var resp int
	err := impl.Client.Call("Plugin.Double", v, &resp)
	if err != nil {
		panic(err)
	}

	return resp
}

// PrintKV forwards the pair to the server as a generic map.
func (impl *testInterfaceClient) PrintKV(key string, value interface{}) {
	err := impl.Client.Call("Plugin.PrintKV", map[string]interface{}{
		"key":   key,
		"value": value,
	}, &struct{}{})
	if err != nil {
		panic(err)
	}
}

// Bidirectional is a no-op over net/rpc (only the gRPC path brokers
// extra streams).
func (impl *testInterfaceClient) Bidirectional() error {
	return nil
}
+
// testInterfaceServer is the RPC server for testInterfaceClient
type testInterfaceServer struct {
	Broker *MuxBroker
	Impl   testInterface
}

// Double exposes Impl.Double over net/rpc.
func (s *testInterfaceServer) Double(arg int, resp *int) error {
	*resp = s.Impl.Double(arg)
	return nil
}

// PrintKV unpacks the generic map sent by the client; panics if "key" is
// not a string (test-only code).
func (s *testInterfaceServer) PrintKV(args map[string]interface{}, _ *struct{}) error {
	s.Impl.PrintKV(args["key"].(string), args["value"])
	return nil
}
+
// testPluginMap can be used for tests as a plugin map
var testPluginMap = map[string]Plugin{
	"test": new(testInterfacePlugin),
}
+
// testGRPCServer is the implementation of our GRPC service.
type testGRPCServer struct {
	// Impl backs the unary calls.
	Impl testInterface
	// broker negotiates the extra streams used by Bidirectional.
	broker *GRPCBroker
}
+
// Double exposes Impl.Double as a unary gRPC call.
func (s *testGRPCServer) Double(
	ctx context.Context,
	req *grpctest.TestRequest) (*grpctest.TestResponse, error) {
	return &grpctest.TestResponse{
		Output: int32(s.Impl.Double(int(req.Input))),
	}, nil
}

// PrintKV unwraps the oneof value (string or int) and forwards it to the
// implementation; panics on an unknown variant (test-only code).
func (s *testGRPCServer) PrintKV(
	ctx context.Context,
	req *grpctest.PrintKVRequest) (*grpctest.PrintKVResponse, error) {
	var v interface{}
	switch rv := req.Value.(type) {
	case *grpctest.PrintKVRequest_ValueString:
		v = rv.ValueString

	case *grpctest.PrintKVRequest_ValueInt:
		v = rv.ValueInt

	default:
		panic(fmt.Sprintf("unknown value: %#v", req.Value))
	}

	s.Impl.PrintKV(req.Key, v)
	return &grpctest.PrintKVResponse{}, nil
}
+
// Bidirectional first dials back to the client-hosted PingPong service on
// the broker stream the client reserved (req.Id), then reserves a new
// stream ID of its own, serves PingPong on it, and returns that ID so the
// client can dial back — exercising the broker in both directions.
func (s *testGRPCServer) Bidirectional(ctx context.Context, req *grpctest.BidirectionalRequest) (*grpctest.BidirectionalResponse, error) {
	conn, err := s.broker.Dial(req.Id)
	if err != nil {
		return nil, err
	}

	pingPongClient := grpctest.NewPingPongClient(conn)
	resp, err := pingPongClient.Ping(ctx, &grpctest.PingRequest{})
	if err != nil {
		return nil, err
	}
	if resp.Msg != "pong" {
		return nil, errors.New("Bad PingPong")
	}

	// Serve our own PingPong on a freshly reserved stream ID.
	nextID := s.broker.NextId()
	go s.broker.AcceptAndServe(nextID, func(opts []grpc.ServerOption) *grpc.Server {
		s := grpc.NewServer(opts...)
		grpctest.RegisterPingPongServer(s, &pingPongServer{})
		return s
	})

	return &grpctest.BidirectionalResponse{
		Id: nextID,
	}, nil
}
+
// pingPongServer answers every Ping with the literal message "pong".
type pingPongServer struct{}

func (p *pingPongServer) Ping(ctx context.Context, req *grpctest.PingRequest) (*grpctest.PongResponse, error) {
	return &grpctest.PongResponse{
		Msg: "pong",
	}, nil
}
+
+func (s testGRPCServer) Stream(stream grpctest.Test_StreamServer) error {
+ for {
+ req, err := stream.Recv()
+ if err != nil {
+ if err != io.EOF {
+ return err
+ }
+ return nil
+ }
+
+ if err := stream.Send(&grpctest.TestResponse{Output: req.Input}); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
// testGRPCClient is an implementation of TestInterface that communicates
// over gRPC.
type testGRPCClient struct {
	Client grpctest.TestClient
	broker *GRPCBroker
}

// Double calls the unary Double RPC; panics on transport errors
// (acceptable in test-only code).
func (c *testGRPCClient) Double(v int) int {
	resp, err := c.Client.Double(context.Background(), &grpctest.TestRequest{
		Input: int32(v),
	})
	if err != nil {
		panic(err)
	}

	return int(resp.Output)
}
+
// PrintKV wraps the value in the matching protobuf oneof variant (string
// or int) and sends it; panics on unsupported types or transport errors.
func (c *testGRPCClient) PrintKV(key string, value interface{}) {
	req := &grpctest.PrintKVRequest{Key: key}
	switch v := value.(type) {
	case string:
		req.Value = &grpctest.PrintKVRequest_ValueString{
			ValueString: v,
		}

	case int:
		req.Value = &grpctest.PrintKVRequest_ValueInt{
			ValueInt: int32(v),
		}

	default:
		panic(fmt.Sprintf("unknown type: %T", value))
	}

	_, err := c.Client.PrintKV(context.Background(), req)
	if err != nil {
		panic(err)
	}
}
+
// Bidirectional reserves a broker stream ID, serves PingPong on it for the
// server to dial back, invokes the server's Bidirectional RPC, then dials
// the stream ID the server reserved in return and pings it — the mirror of
// testGRPCServer.Bidirectional.
func (c *testGRPCClient) Bidirectional() error {
	nextID := c.broker.NextId()
	go c.broker.AcceptAndServe(nextID, func(opts []grpc.ServerOption) *grpc.Server {
		s := grpc.NewServer(opts...)
		grpctest.RegisterPingPongServer(s, &pingPongServer{})
		return s
	})

	resp, err := c.Client.Bidirectional(context.Background(), &grpctest.BidirectionalRequest{
		Id: nextID,
	})
	if err != nil {
		return err
	}

	conn, err := c.broker.Dial(resp.Id)
	if err != nil {
		return err
	}

	pingPongClient := grpctest.NewPingPongClient(conn)
	pResp, err := pingPongClient.Ping(context.Background(), &grpctest.PingRequest{})
	if err != nil {
		return err
	}
	if pResp.Msg != "pong" {
		return errors.New("Bad PingPong")
	}
	return nil
}
+
+// Stream sends a series of requests from [start, stop) using a bidirectional
+// streaming service, and returns the streamed responses.
+func (impl *testGRPCClient) Stream(start, stop int32) ([]int32, error) {
+ if stop <= start {
+ return nil, fmt.Errorf("invalid range [%d, %d)", start, stop)
+ }
+ streamClient, err := impl.Client.Stream(context.Background())
+ if err != nil {
+ return nil, err
+ }
+
+ var resp []int32
+ for i := start; i < stop; i++ {
+ if err := streamClient.Send(&grpctest.TestRequest{i}); err != nil {
+ return resp, err
+ }
+
+ out, err := streamClient.Recv()
+ if err != nil {
+ return resp, err
+ }
+
+ resp = append(resp, out.Output)
+ }
+
+ streamClient.CloseSend()
+
+ return resp, nil
+}
+
// helperProcess builds an *exec.Cmd that re-runs this test binary so only
// TestHelperProcess executes, forwarding s as the helper's arguments (after
// "--") and pinning the env vars the plugin handshake expects.
func helperProcess(s ...string) *exec.Cmd {
	args := append([]string{"-test.run=TestHelperProcess", "--"}, s...)

	env := []string{
		"GO_WANT_HELPER_PROCESS=1",
		"PLUGIN_MIN_PORT=10000",
		"PLUGIN_MAX_PORT=25000",
	}

	cmd := exec.Command(os.Args[0], args...)
	cmd.Env = append(env, os.Environ()...)
	return cmd
}
+
// This is not a real test. This is just a helper process kicked off by
// tests (via helperProcess). Each command name selects one deliberately
// well- or mis-behaved plugin process for the host-side tests to talk to.
func TestHelperProcess(*testing.T) {
	if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
		return
	}

	defer os.Exit(0)

	// Strip everything up to and including the "--" separator added by
	// helperProcess; what's left are the helper's own arguments.
	args := os.Args
	for len(args) > 0 {
		if args[0] == "--" {
			args = args[1:]
			break
		}

		args = args[1:]
	}

	if len(args) == 0 {
		fmt.Fprintf(os.Stderr, "No command\n")
		os.Exit(2)
	}

	// override testPluginMap with one that uses
	// hclog logger on its implementation
	pluginLogger := hclog.New(&hclog.LoggerOptions{
		Level:      hclog.Trace,
		Output:     os.Stderr,
		JSONFormat: true,
	})

	testPlugin := &testInterfaceImpl{
		logger: pluginLogger,
	}

	testPluginMap := map[string]Plugin{
		"test": &testInterfacePlugin{Impl: testPlugin},
	}

	cmd, args := args[0], args[1:]
	switch cmd {
	case "bad-version":
		// If we have an arg, we write there on start
		if len(args) > 0 {
			path := args[0]
			err := ioutil.WriteFile(path, []byte("foo"), 0644)
			if err != nil {
				panic(err)
			}
		}

		// The "%d1" is intentional: it emits a corrupted protocol version
		// so the host's handshake validation can be tested.
		fmt.Printf("%d|%d1|tcp|:1234\n", CoreProtocolVersion, testHandshake.ProtocolVersion)
		<-make(chan int)
	case "invalid-rpc-address":
		fmt.Println("lolinvalid")
	case "mock":
		// Valid handshake line, then block forever.
		fmt.Printf("%d|%d|tcp|:1234\n", CoreProtocolVersion, testHandshake.ProtocolVersion)
		<-make(chan int)
	case "start-timeout":
		// Never prints a handshake; exercises the host's startup timeout.
		time.Sleep(1 * time.Minute)
		os.Exit(1)
	case "stderr":
		fmt.Printf("%d|%d|tcp|:1234\n", CoreProtocolVersion, testHandshake.ProtocolVersion)
		os.Stderr.WriteString("HELLO\n")
		os.Stderr.WriteString("WORLD\n")
	case "stderr-json":
		// write values that might be JSON, but aren't KVs
		fmt.Printf("%d|%d|tcp|:1234\n", CoreProtocolVersion, testHandshake.ProtocolVersion)
		os.Stderr.WriteString("[\"HELLO\"]\n")
		os.Stderr.WriteString("12345\n")
	case "stdin":
		// Exits 0 only if exactly "hello" arrives on stdin.
		fmt.Printf("%d|%d|tcp|:1234\n", CoreProtocolVersion, testHandshake.ProtocolVersion)
		data := make([]byte, 5)
		if _, err := os.Stdin.Read(data); err != nil {
			log.Printf("stdin read error: %s", err)
			os.Exit(100)
		}

		if string(data) == "hello" {
			os.Exit(0)
		}

		os.Exit(1)
	case "cleanup":
		// Create a defer to write the file. This tests that we get cleaned
		// up properly versus just calling os.Exit
		path := args[0]
		defer func() {
			err := ioutil.WriteFile(path, []byte("foo"), 0644)
			if err != nil {
				panic(err)
			}
		}()

		Serve(&ServeConfig{
			HandshakeConfig: testHandshake,
			Plugins:         testPluginMap,
		})

		// Exit
		return
	case "test-grpc":
		Serve(&ServeConfig{
			HandshakeConfig: testHandshake,
			Plugins:         testPluginMap,
			GRPCServer:      DefaultGRPCServer,
		})

		// Shouldn't reach here but make sure we exit anyways
		os.Exit(0)
	case "test-grpc-tls":
		// Serve!
		Serve(&ServeConfig{
			HandshakeConfig: testHandshake,
			Plugins:         testPluginMap,
			GRPCServer:      DefaultGRPCServer,
			TLSProvider:     helperTLSProvider,
		})

		// Shouldn't reach here but make sure we exit anyways
		os.Exit(0)
	case "test-interface":
		Serve(&ServeConfig{
			HandshakeConfig: testHandshake,
			Plugins:         testPluginMap,
		})

		// Shouldn't reach here but make sure we exit anyways
		os.Exit(0)
	case "test-interface-logger-netrpc":
		Serve(&ServeConfig{
			HandshakeConfig: testHandshake,
			Plugins:         testPluginMap,
		})
		// Shouldn't reach here but make sure we exit anyways
		os.Exit(0)
	case "test-interface-logger-grpc":
		Serve(&ServeConfig{
			HandshakeConfig: testHandshake,
			Plugins:         testPluginMap,
			GRPCServer:      DefaultGRPCServer,
		})
		// Shouldn't reach here but make sure we exit anyways
		os.Exit(0)
	case "test-interface-daemon":
		// Serve!
		Serve(&ServeConfig{
			HandshakeConfig: testHandshake,
			Plugins:         testPluginMap,
		})

		// Shouldn't reach here but make sure we exit anyways
		os.Exit(0)
	case "test-interface-tls":
		// Serve!
		Serve(&ServeConfig{
			HandshakeConfig: testHandshake,
			Plugins:         testPluginMap,
			TLSProvider:     helperTLSProvider,
		})

		// Shouldn't reach here but make sure we exit anyways
		os.Exit(0)
	default:
		fmt.Fprintf(os.Stderr, "Unknown command: %q\n", cmd)
		os.Exit(2)
	}
}
+
+func helperTLSProvider() (*tls.Config, error) {
+ serverCert, err := tls.X509KeyPair([]byte(TestClusterServerCert), []byte(TestClusterServerKey))
+ if err != nil {
+ return nil, err
+ }
+
+ rootCAs := x509.NewCertPool()
+ rootCAs.AppendCertsFromPEM([]byte(TestClusterCACert))
+ tlsConfig := &tls.Config{
+ Certificates: []tls.Certificate{serverCert},
+ RootCAs: rootCAs,
+ ClientCAs: rootCAs,
+ ClientAuth: tls.VerifyClientCertIfGiven,
+ ServerName: "127.0.0.1",
+ }
+ tlsConfig.BuildNameToCertificate()
+
+ return tlsConfig, nil
+}
+
// Static TLS fixtures for the TLS helper-process tests. These are
// long-lived test-only credentials (not secrets in any meaningful sense):
// a CA certificate plus a server certificate/key pair issued by that CA
// for "cert.myvault.com" / 127.0.0.1.
const (
	// TestClusterCACert is the PEM-encoded root CA certificate.
	TestClusterCACert = `-----BEGIN CERTIFICATE-----
MIIDPjCCAiagAwIBAgIUfIKsF2VPT7sdFcKOHJH2Ii6K4MwwDQYJKoZIhvcNAQEL
BQAwFjEUMBIGA1UEAxMLbXl2YXVsdC5jb20wIBcNMTYwNTAyMTYwNTQyWhgPMjA2
NjA0MjAxNjA2MTJaMBYxFDASBgNVBAMTC215dmF1bHQuY29tMIIBIjANBgkqhkiG
9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuOimEXawD2qBoLCFP3Skq5zi1XzzcMAJlfdS
xz9hfymuJb+cN8rB91HOdU9wQCwVKnkUtGWxUnMp0tT0uAZj5NzhNfyinf0JGAbP
67HDzVZhGBHlHTjPX0638yaiUx90cTnucX0N20SgCYct29dMSgcPl+W78D3Jw3xE
JsHQPYS9ASe2eONxG09F/qNw7w/RO5/6WYoV2EmdarMMxq52pPe2chtNMQdSyOUb
cCcIZyk4QVFZ1ZLl6jTnUPb+JoCx1uMxXvMek4NF/5IL0Wr9dw2gKXKVKoHDr6SY
WrCONRw61A5Zwx1V+kn73YX3USRlkufQv/ih6/xThYDAXDC9cwIDAQABo4GBMH8w
DgYDVR0PAQH/BAQDAgEGMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYEFOuKvPiU
G06iHkRXAOeMiUdBfHFyMB8GA1UdIwQYMBaAFOuKvPiUG06iHkRXAOeMiUdBfHFy
MBwGA1UdEQQVMBOCC215dmF1bHQuY29thwR/AAABMA0GCSqGSIb3DQEBCwUAA4IB
AQBcN/UdAMzc7UjRdnIpZvO+5keBGhL/vjltnGM1dMWYHa60Y5oh7UIXF+P1RdNW
n7g80lOyvkSR15/r1rDkqOK8/4oruXU31EcwGhDOC4hU6yMUy4ltV/nBoodHBXNh
MfKiXeOstH1vdI6G0P6W93Bcww6RyV1KH6sT2dbETCw+iq2VN9CrruGIWzd67UT/
spe/kYttr3UYVV3O9kqgffVVgVXg/JoRZ3J7Hy2UEXfh9UtWNanDlRuXaZgE9s/d
CpA30CHpNXvKeyNeW2ktv+2nAbSpvNW+e6MecBCTBIoDSkgU8ShbrzmDKVwNN66Q
5gn6KxUPBKHEtNzs5DgGM7nq
-----END CERTIFICATE-----`

	// TestClusterServerCert is the PEM-encoded server certificate issued
	// by TestClusterCACert.
	TestClusterServerCert = `-----BEGIN CERTIFICATE-----
MIIDtzCCAp+gAwIBAgIUBLqh6ctGWVDUxFhxJX7m6S/bnrcwDQYJKoZIhvcNAQEL
BQAwFjEUMBIGA1UEAxMLbXl2YXVsdC5jb20wIBcNMTYwNTAyMTYwOTI2WhgPMjA2
NjA0MjAxNTA5NTZaMBsxGTAXBgNVBAMTEGNlcnQubXl2YXVsdC5jb20wggEiMA0G
CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDY3gPB29kkdbu0mPO6J0efagQhSiXB
9OyDuLf5sMk6CVDWVWal5hISkyBmw/lXgF7qC2XFKivpJOrcGQd5Ep9otBqyJLzI
b0IWdXuPIrVnXDwcdWr86ybX2iC42zKWfbXgjzGijeAVpl0UJLKBj+fk5q6NvkRL
5FUL6TRV7Krn9mrmnrV9J5IqV15pTd9W2aVJ6IqWvIPCACtZKulqWn4707uy2X2W
1Stq/5qnp1pDshiGk1VPyxCwQ6yw3iEcgecbYo3vQfhWcv7Q8LpSIM9ZYpXu6OmF
+czqRZS9gERl+wipmmrN1MdYVrTuQem21C/PNZ4jo4XUk1SFx6JrcA+lAgMBAAGj
gfUwgfIwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBSe
Cl9WV3BjGCwmS/KrDSLRjfwyqjAfBgNVHSMEGDAWgBTrirz4lBtOoh5EVwDnjIlH
QXxxcjA7BggrBgEFBQcBAQQvMC0wKwYIKwYBBQUHMAKGH2h0dHA6Ly8xMjcuMC4w
LjE6ODIwMC92MS9wa2kvY2EwIQYDVR0RBBowGIIQY2VydC5teXZhdWx0LmNvbYcE
fwAAATAxBgNVHR8EKjAoMCagJKAihiBodHRwOi8vMTI3LjAuMC4xOjgyMDAvdjEv
cGtpL2NybDANBgkqhkiG9w0BAQsFAAOCAQEAWGholPN8buDYwKbUiDavbzjsxUIX
lU4MxEqOHw7CD3qIYIauPboLvB9EldBQwhgOOy607Yvdg3rtyYwyBFwPhHo/hK3Z
6mn4hc6TF2V+AUdHBvGzp2dbYLeo8noVoWbQ/lBulggwlIHNNF6+a3kALqsqk1Ch
f/hzsjFnDhAlNcYFgG8TgfE2lE/FckvejPqBffo7Q3I+wVAw0buqiz5QL81NOT+D
Y2S9LLKLRaCsWo9wRU1Az4Rhd7vK5SEMh16jJ82GyEODWPvuxOTI1MnzfnbWyLYe
TTp6YBjGMVf1I6NEcWNur7U17uIOiQjMZ9krNvoMJ1A/cxCoZ98QHgcIPg==
-----END CERTIFICATE-----`

	// TestClusterServerKey is the PEM-encoded private key matching
	// TestClusterServerCert.
	TestClusterServerKey = `-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEA2N4DwdvZJHW7tJjzuidHn2oEIUolwfTsg7i3+bDJOglQ1lVm
peYSEpMgZsP5V4Be6gtlxSor6STq3BkHeRKfaLQasiS8yG9CFnV7jyK1Z1w8HHVq
/Osm19oguNsyln214I8xoo3gFaZdFCSygY/n5Oaujb5ES+RVC+k0Veyq5/Zq5p61
fSeSKldeaU3fVtmlSeiKlryDwgArWSrpalp+O9O7stl9ltUrav+ap6daQ7IYhpNV
T8sQsEOssN4hHIHnG2KN70H4VnL+0PC6UiDPWWKV7ujphfnM6kWUvYBEZfsIqZpq
zdTHWFa07kHpttQvzzWeI6OF1JNUhceia3APpQIDAQABAoIBAQCH3vEzr+3nreug
RoPNCXcSJXXY9X+aeT0FeeGqClzIg7Wl03OwVOjVwl/2gqnhbIgK0oE8eiNwurR6
mSPZcxV0oAJpwiKU4T/imlCDaReGXn86xUX2l82KRxthNdQH/VLKEmzij0jpx4Vh
bWx5SBPdkbmjDKX1dmTiRYWIn/KjyNPvNvmtwdi8Qluhf4eJcNEUr2BtblnGOmfL
FdSu+brPJozpoQ1QdDnbAQRgqnh7Shl0tT85whQi0uquqIj1gEOGVjmBvDDnL3GV
WOENTKqsmIIoEzdZrql1pfmYTk7WNaD92bfpN128j8BF7RmAV4/DphH0pvK05y9m
tmRhyHGxAoGBAOV2BBocsm6xup575VqmFN+EnIOiTn+haOvfdnVsyQHnth63fOQx
PNtMpTPR1OMKGpJ13e2bV0IgcYRsRkScVkUtoa/17VIgqZXffnJJ0A/HT67uKBq3
8o7RrtyK5N20otw0lZHyqOPhyCdpSsurDhNON1kPVJVYY4N1RiIxfut/AoGBAPHz
HfsJ5ZkyELE9N/r4fce04lprxWH+mQGK0/PfjS9caXPhj/r5ZkVMvzWesF3mmnY8
goE5S35TuTvV1+6rKGizwlCFAQlyXJiFpOryNWpLwCmDDSzLcm+sToAlML3tMgWU
jM3dWHx3C93c3ft4rSWJaUYI9JbHsMzDW6Yh+GbbAoGBANIbKwxh5Hx5XwEJP2yu
kIROYCYkMy6otHLujgBdmPyWl+suZjxoXWoMl2SIqR8vPD+Jj6mmyNJy9J6lqf3f
DRuQ+fEuBZ1i7QWfvJ+XuN0JyovJ5Iz6jC58D1pAD+p2IX3y5FXcVQs8zVJRFjzB
p0TEJOf2oqORaKWRd6ONoMKvAoGALKu6aVMWdQZtVov6/fdLIcgf0pn7Q3CCR2qe
X3Ry2L+zKJYIw0mwvDLDSt8VqQCenB3n6nvtmFFU7ds5lvM67rnhsoQcAOaAehiS
rl4xxoJd5Ewx7odRhZTGmZpEOYzFo4odxRSM9c30/u18fqV1Mm0AZtHYds4/sk6P
aUj0V+kCgYBMpGrJk8RSez5g0XZ35HfpI4ENoWbiwB59FIpWsLl2LADEh29eC455
t9Muq7MprBVBHQo11TMLLFxDIjkuMho/gcKgpYXCt0LfiNm8EZehvLJUXH+3WqUx
we6ywrbFCs6LaxaOCtTiLsN+GbZCatITL0UJaeBmTAbiw0KQjUuZPQ==
-----END RSA PRIVATE KEY-----`
)
--- /dev/null
+package plugin
+
+import (
+ "time"
+)
+
// pidAlive checks whether a pid is alive. It delegates to the
// platform-specific _pidAlive implementation (see process_posix.go and
// process_windows.go), which is selected at build time.
func pidAlive(pid int) bool {
	return _pidAlive(pid)
}
+
+// pidWait blocks for a process to exit.
+func pidWait(pid int) error {
+ ticker := time.NewTicker(1 * time.Second)
+ defer ticker.Stop()
+
+ for range ticker.C {
+ if !pidAlive(pid) {
+ break
+ }
+ }
+
+ return nil
+}
--- /dev/null
+// +build !windows
+
+package plugin
+
+import (
+ "os"
+ "syscall"
+)
+
+// _pidAlive tests whether a process is alive or not by sending it Signal 0,
+// since Go otherwise has no way to test this.
+func _pidAlive(pid int) bool {
+ proc, err := os.FindProcess(pid)
+ if err == nil {
+ err = proc.Signal(syscall.Signal(0))
+ }
+
+ return err == nil
+}
--- /dev/null
+package plugin
+
+import (
+ "syscall"
+)
+
const (
	// exit_STILL_ACTIVE is the exit code GetExitCodeProcess reports for a
	// process that has not terminated yet.
	// Weird name but matches the MSDN docs
	exit_STILL_ACTIVE = 259

	// processDesiredAccess is the access mask requested from OpenProcess:
	// just enough to query the exit code and wait on the handle.
	processDesiredAccess = syscall.STANDARD_RIGHTS_READ |
		syscall.PROCESS_QUERY_INFORMATION |
		syscall.SYNCHRONIZE
)
+
+// _pidAlive tests whether a process is alive or not
+func _pidAlive(pid int) bool {
+ h, err := syscall.OpenProcess(processDesiredAccess, false, uint32(pid))
+ if err != nil {
+ return false
+ }
+
+ var ec uint32
+ if e := syscall.GetExitCodeProcess(h, &ec); e != nil {
+ return false
+ }
+
+ return ec == exit_STILL_ACTIVE
+}
--- /dev/null
+package plugin
+
+import (
+ "io"
+ "net"
+)
+
// Protocol is an enum representing the types of protocols.
type Protocol string

const (
	// ProtocolInvalid is the zero value: no protocol was negotiated.
	ProtocolInvalid Protocol = ""

	// ProtocolNetRPC serves plugins over Go's net/rpc.
	ProtocolNetRPC Protocol = "netrpc"

	// ProtocolGRPC serves plugins over gRPC.
	ProtocolGRPC Protocol = "grpc"
)
+
// ServerProtocol is an interface that must be implemented for new plugin
// protocols to be servers.
type ServerProtocol interface {
	// Init is called once to configure and initialize the protocol, but
	// not start listening. This is the point at which all validation should
	// be done and errors returned.
	Init() error

	// Config is extra configuration to be outputted to stdout. This will
	// be automatically base64 encoded to ensure it can be parsed properly.
	// This can be an empty string if additional configuration is not needed.
	Config() string

	// Serve is called to serve connections on the given listener. This should
	// continue until the listener is closed.
	Serve(net.Listener)
}
+
// ClientProtocol is an interface that must be implemented for new plugin
// protocols to be clients. Closing the client tears down the underlying
// connection.
type ClientProtocol interface {
	io.Closer

	// Dispense dispenses a new instance of the plugin with the given name.
	Dispense(string) (interface{}, error)

	// Ping checks that the client connection is still healthy.
	Ping() error
}
--- /dev/null
+package plugin
+
+import (
+ "crypto/tls"
+ "fmt"
+ "io"
+ "net"
+ "net/rpc"
+
+ "github.com/hashicorp/yamux"
+)
+
// RPCClient connects to an RPCServer over net/rpc to dispense plugin types.
type RPCClient struct {
	// broker multiplexes additional connections for dispensed plugins.
	broker *MuxBroker
	// control is the RPC connection used for Control.* and Dispenser.* calls.
	control *rpc.Client
	// plugins maps plugin names to their client-side constructors.
	plugins map[string]Plugin

	// These are the streams used for the various stdout/err overrides
	stdout, stderr net.Conn
}
+
+// newRPCClient creates a new RPCClient. The Client argument is expected
+// to be successfully started already with a lock held.
+func newRPCClient(c *Client) (*RPCClient, error) {
+ // Connect to the client
+ conn, err := net.Dial(c.address.Network(), c.address.String())
+ if err != nil {
+ return nil, err
+ }
+ if tcpConn, ok := conn.(*net.TCPConn); ok {
+ // Make sure to set keep alive so that the connection doesn't die
+ tcpConn.SetKeepAlive(true)
+ }
+
+ if c.config.TLSConfig != nil {
+ conn = tls.Client(conn, c.config.TLSConfig)
+ }
+
+ // Create the actual RPC client
+ result, err := NewRPCClient(conn, c.config.Plugins)
+ if err != nil {
+ conn.Close()
+ return nil, err
+ }
+
+ // Begin the stream syncing so that stdin, out, err work properly
+ err = result.SyncStreams(
+ c.config.SyncStdout,
+ c.config.SyncStderr)
+ if err != nil {
+ result.Close()
+ return nil, err
+ }
+
+ return result, nil
+}
+
+// NewRPCClient creates a client from an already-open connection-like value.
+// Dial is typically used instead.
+func NewRPCClient(conn io.ReadWriteCloser, plugins map[string]Plugin) (*RPCClient, error) {
+ // Create the yamux client so we can multiplex
+ mux, err := yamux.Client(conn, nil)
+ if err != nil {
+ conn.Close()
+ return nil, err
+ }
+
+ // Connect to the control stream.
+ control, err := mux.Open()
+ if err != nil {
+ mux.Close()
+ return nil, err
+ }
+
+ // Connect stdout, stderr streams
+ stdstream := make([]net.Conn, 2)
+ for i, _ := range stdstream {
+ stdstream[i], err = mux.Open()
+ if err != nil {
+ mux.Close()
+ return nil, err
+ }
+ }
+
+ // Create the broker and start it up
+ broker := newMuxBroker(mux)
+ go broker.Run()
+
+ // Build the client using our broker and control channel.
+ return &RPCClient{
+ broker: broker,
+ control: rpc.NewClient(control),
+ plugins: plugins,
+ stdout: stdstream[0],
+ stderr: stdstream[1],
+ }, nil
+}
+
+// SyncStreams should be called to enable syncing of stdout,
+// stderr with the plugin.
+//
+// This will return immediately and the syncing will continue to happen
+// in the background. You do not need to launch this in a goroutine itself.
+//
+// This should never be called multiple times.
+func (c *RPCClient) SyncStreams(stdout io.Writer, stderr io.Writer) error {
+ go copyStream("stdout", stdout, c.stdout)
+ go copyStream("stderr", stderr, c.stderr)
+ return nil
+}
+
+// Close closes the connection. The client is no longer usable after this
+// is called.
+func (c *RPCClient) Close() error {
+ // Call the control channel and ask it to gracefully exit. If this
+ // errors, then we save it so that we always return an error but we
+ // want to try to close the other channels anyways.
+ var empty struct{}
+ returnErr := c.control.Call("Control.Quit", true, &empty)
+
+ // Close the other streams we have
+ if err := c.control.Close(); err != nil {
+ return err
+ }
+ if err := c.stdout.Close(); err != nil {
+ return err
+ }
+ if err := c.stderr.Close(); err != nil {
+ return err
+ }
+ if err := c.broker.Close(); err != nil {
+ return err
+ }
+
+ // Return back the error we got from Control.Quit. This is very important
+ // since we MUST return non-nil error if this fails so that Client.Kill
+ // will properly try a process.Kill.
+ return returnErr
+}
+
+func (c *RPCClient) Dispense(name string) (interface{}, error) {
+ p, ok := c.plugins[name]
+ if !ok {
+ return nil, fmt.Errorf("unknown plugin type: %s", name)
+ }
+
+ var id uint32
+ if err := c.control.Call(
+ "Dispenser.Dispense", name, &id); err != nil {
+ return nil, err
+ }
+
+ conn, err := c.broker.Dial(id)
+ if err != nil {
+ return nil, err
+ }
+
+ return p.Client(c.broker, rpc.NewClient(conn))
+}
+
+// Ping pings the connection to ensure it is still alive.
+//
+// The error from the RPC call is returned exactly if you want to inspect
+// it for further error analysis. Any error returned from here would indicate
+// that the connection to the plugin is not healthy.
+func (c *RPCClient) Ping() error {
+ var empty struct{}
+ return c.control.Call("Control.Ping", true, &empty)
+}
--- /dev/null
+package plugin
+
+import (
+ "bytes"
+ "io"
+ "os"
+ "sync"
+ "testing"
+ "time"
+
+ hclog "github.com/hashicorp/go-hclog"
+)
+
+func TestClient_App(t *testing.T) {
+ pluginLogger := hclog.New(&hclog.LoggerOptions{
+ Level: hclog.Trace,
+ Output: os.Stderr,
+ JSONFormat: true,
+ })
+
+ testPlugin := &testInterfaceImpl{
+ logger: pluginLogger,
+ }
+
+ client, _ := TestPluginRPCConn(t, map[string]Plugin{
+ "test": &testInterfacePlugin{Impl: testPlugin},
+ }, nil)
+ defer client.Close()
+
+ raw, err := client.Dispense("test")
+ if err != nil {
+ t.Fatalf("err: %s", err)
+ }
+
+ impl, ok := raw.(testInterface)
+ if !ok {
+ t.Fatalf("bad: %#v", raw)
+ }
+
+ result := impl.Double(21)
+ if result != 42 {
+ t.Fatalf("bad: %#v", result)
+ }
+}
+
// TestClient_syncStreams verifies that bytes written to the server's stdout
// and stderr pipes are forwarded through SyncStreams to the client-side
// writers. The copies run asynchronously, so the test sleeps briefly to let
// them drain before asserting.
func TestClient_syncStreams(t *testing.T) {
	// Create streams for the server that we can talk to
	stdout_r, stdout_w := io.Pipe()
	stderr_r, stderr_w := io.Pipe()

	client, _ := TestPluginRPCConn(t, map[string]Plugin{}, &TestOptions{
		ServerStdout: stdout_r,
		ServerStderr: stderr_r,
	})

	// Start the data copying. safeBuffer is used because the sync
	// goroutines and this test touch the buffers concurrently.
	var stdout_out, stderr_out safeBuffer
	stdout := &safeBuffer{
		b: bytes.NewBufferString("stdouttest"),
	}
	stderr := &safeBuffer{
		b: bytes.NewBufferString("stderrtest"),
	}
	go client.SyncStreams(&stdout_out, &stderr_out)
	go io.Copy(stdout_w, stdout)
	go io.Copy(stderr_w, stderr)

	// Unfortunately I can't think of a better way to make sure all the
	// copies above go through so let's just exit.
	time.Sleep(100 * time.Millisecond)

	// Close everything, and lets test the result
	client.Close()
	stdout_w.Close()
	stderr_w.Close()

	if v := stdout_out.String(); v != "stdouttest" {
		t.Fatalf("bad: %q", v)
	}
	if v := stderr_out.String(); v != "stderrtest" {
		t.Fatalf("bad: %q", v)
	}
}
+
// safeBuffer is a bytes.Buffer guarded by a mutex so the stream-sync
// goroutines and the test can access it concurrently. The zero value is
// ready to use; the buffer is created lazily on first access.
type safeBuffer struct {
	sync.Mutex
	b *bytes.Buffer
}

// buf returns the underlying buffer, allocating it on first use.
// Callers must hold the mutex. This replaces the identical nil-check
// previously duplicated in Write, Read, and String.
func (s *safeBuffer) buf() *bytes.Buffer {
	if s.b == nil {
		s.b = new(bytes.Buffer)
	}
	return s.b
}

func (s *safeBuffer) Write(p []byte) (n int, err error) {
	s.Lock()
	defer s.Unlock()
	return s.buf().Write(p)
}

func (s *safeBuffer) Read(p []byte) (n int, err error) {
	s.Lock()
	defer s.Unlock()
	return s.buf().Read(p)
}

func (s *safeBuffer) String() string {
	s.Lock()
	defer s.Unlock()
	return s.buf().String()
}
--- /dev/null
+package plugin
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "net/rpc"
+ "sync"
+
+ "github.com/hashicorp/yamux"
+)
+
// RPCServer listens for network connections and then dispenses interface
// implementations over net/rpc.
//
// After setting the fields below, they shouldn't be read again directly
// from the structure which may be reading/writing them concurrently.
type RPCServer struct {
	// Plugins maps plugin names to their server-side constructors.
	Plugins map[string]Plugin

	// Stdout, Stderr are what this server will use instead of the
	// normal stdin/out/err. This is because due to the multi-process nature
	// of our plugin system, we can't use the normal process values so we
	// make our own custom one we pipe across.
	Stdout io.Reader
	Stderr io.Reader

	// DoneCh should be set to a non-nil channel that will be closed
	// when the control requests the RPC server to end.
	DoneCh chan<- struct{}

	// lock guards DoneCh so it is closed at most once (see done()).
	lock sync.Mutex
}
+
// Init implements ServerProtocol; net/rpc needs no additional setup.
func (s *RPCServer) Init() error { return nil }

// Config implements ServerProtocol; net/rpc emits no extra handshake
// configuration.
func (s *RPCServer) Config() string { return "" }
+
+// ServerProtocol impl.
+func (s *RPCServer) Serve(lis net.Listener) {
+ for {
+ conn, err := lis.Accept()
+ if err != nil {
+ log.Printf("[ERR] plugin: plugin server: %s", err)
+ return
+ }
+
+ go s.ServeConn(conn)
+ }
+}
+
+// ServeConn runs a single connection.
+//
+// ServeConn blocks, serving the connection until the client hangs up.
+func (s *RPCServer) ServeConn(conn io.ReadWriteCloser) {
+ // First create the yamux server to wrap this connection
+ mux, err := yamux.Server(conn, nil)
+ if err != nil {
+ conn.Close()
+ log.Printf("[ERR] plugin: error creating yamux server: %s", err)
+ return
+ }
+
+ // Accept the control connection
+ control, err := mux.Accept()
+ if err != nil {
+ mux.Close()
+ if err != io.EOF {
+ log.Printf("[ERR] plugin: error accepting control connection: %s", err)
+ }
+
+ return
+ }
+
+ // Connect the stdstreams (in, out, err)
+ stdstream := make([]net.Conn, 2)
+ for i, _ := range stdstream {
+ stdstream[i], err = mux.Accept()
+ if err != nil {
+ mux.Close()
+ log.Printf("[ERR] plugin: accepting stream %d: %s", i, err)
+ return
+ }
+ }
+
+ // Copy std streams out to the proper place
+ go copyStream("stdout", stdstream[0], s.Stdout)
+ go copyStream("stderr", stdstream[1], s.Stderr)
+
+ // Create the broker and start it up
+ broker := newMuxBroker(mux)
+ go broker.Run()
+
+ // Use the control connection to build the dispenser and serve the
+ // connection.
+ server := rpc.NewServer()
+ server.RegisterName("Control", &controlServer{
+ server: s,
+ })
+ server.RegisterName("Dispenser", &dispenseServer{
+ broker: broker,
+ plugins: s.Plugins,
+ })
+ server.ServeConn(control)
+}
+
+// done is called internally by the control server to trigger the
+// doneCh to close which is listened to by the main process to cleanly
+// exit.
+func (s *RPCServer) done() {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ if s.DoneCh != nil {
+ close(s.DoneCh)
+ s.DoneCh = nil
+ }
+}
+
// controlServer is the RPC receiver for the "Control" service: it handles
// connection-lifecycle requests (Ping, Quit) from the client.
type controlServer struct {
	server *RPCServer
}
+
// Ping can be called to verify the connection (and likely the binary)
// is still alive to a plugin. The reply carries no data; a successful
// round trip is the signal.
func (c *controlServer) Ping(
	null bool, response *struct{}) error {
	*response = struct{}{}
	return nil
}
+
+func (c *controlServer) Quit(
+ null bool, response *struct{}) error {
+ // End the server
+ c.server.done()
+
+ // Always return true
+ *response = struct{}{}
+
+ return nil
+}
+
// dispenseServer dispenses various interface implementations for Terraform.
type dispenseServer struct {
	broker  *MuxBroker
	plugins map[string]Plugin
}
+
+func (d *dispenseServer) Dispense(
+ name string, response *uint32) error {
+ // Find the function to create this implementation
+ p, ok := d.plugins[name]
+ if !ok {
+ return fmt.Errorf("unknown plugin type: %s", name)
+ }
+
+ // Create the implementation first so we know if there is an error.
+ impl, err := p.Server(d.broker)
+ if err != nil {
+ // We turn the error into an errors error so that it works across RPC
+ return errors.New(err.Error())
+ }
+
+ // Reserve an ID for our implementation
+ id := d.broker.NextId()
+ *response = id
+
+ // Run the rest in a goroutine since it can only happen once this RPC
+ // call returns. We wait for a connection for the plugin implementation
+ // and serve it.
+ go func() {
+ conn, err := d.broker.Accept(id)
+ if err != nil {
+ log.Printf("[ERR] go-plugin: plugin dispense error: %s: %s", name, err)
+ return
+ }
+
+ serve(conn, "Plugin", impl)
+ }()
+
+ return nil
+}
+
// serve registers v under the given RPC service name on a fresh net/rpc
// server and serves the connection until the client hangs up. A failed
// registration is logged and the connection is left untouched.
func serve(conn io.ReadWriteCloser, name string, v interface{}) {
	s := rpc.NewServer()
	if err := s.RegisterName(name, v); err != nil {
		log.Printf("[ERR] go-plugin: plugin dispense error: %s", err)
		return
	}

	s.ServeConn(conn)
}
--- /dev/null
+package plugin
+
+import (
+ "crypto/tls"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net"
+ "os"
+ "os/signal"
+ "runtime"
+ "strconv"
+ "sync/atomic"
+
+ "github.com/hashicorp/go-hclog"
+
+ "google.golang.org/grpc"
+)
+
// CoreProtocolVersion is the ProtocolVersion of the plugin system itself.
// We will increment this whenever we change any protocol behavior. This
// will invalidate any prior plugins but will at least allow us to iterate
// on the core in a safe way. We will do our best to do this very
// infrequently. It is emitted as the first field of the handshake line
// printed by Serve.
const CoreProtocolVersion = 1
+
// HandshakeConfig is the configuration used by client and servers to
// handshake before starting a plugin connection. This is embedded by
// both ServeConfig and ClientConfig.
//
// In practice, the plugin host creates a HandshakeConfig that is exported
// and plugins then can easily consume it.
type HandshakeConfig struct {
	// ProtocolVersion is the version that clients must match on to
	// agree they can communicate. This should match the ProtocolVersion
	// set on ClientConfig when using a plugin.
	ProtocolVersion uint

	// MagicCookieKey and value are used as a very basic verification
	// that a plugin is intended to be launched. This is not a security
	// measure, just a UX feature. If the magic cookie doesn't match,
	// we show human-friendly output. The key is looked up as an
	// environment variable in the plugin process.
	MagicCookieKey   string
	MagicCookieValue string
}
+
// ServeConfig configures what sorts of plugins are served.
type ServeConfig struct {
	// HandshakeConfig is the configuration that must match clients.
	HandshakeConfig

	// TLSProvider is a function that returns a configured tls.Config.
	// It is invoked once during Serve; an error aborts plugin startup.
	TLSProvider func() (*tls.Config, error)

	// Plugins are the plugins that are served.
	Plugins map[string]Plugin

	// GRPCServer should be non-nil to enable serving the plugins over
	// gRPC. This is a function to create the server when needed with the
	// given server options. The server options populated by go-plugin will
	// be for TLS if set. You may modify the input slice.
	//
	// Note that the grpc.Server will automatically be registered with
	// the gRPC health checking service. This is not optional since go-plugin
	// relies on this to implement Ping().
	GRPCServer func([]grpc.ServerOption) *grpc.Server

	// Logger is used to pass a logger into the server. If none is provided the
	// server will create a default logger.
	Logger hclog.Logger
}
+
+// Protocol returns the protocol that this server should speak.
+func (c *ServeConfig) Protocol() Protocol {
+ result := ProtocolNetRPC
+ if c.GRPCServer != nil {
+ result = ProtocolGRPC
+ }
+
+ return result
+}
+
// Serve serves the plugins given by ServeConfig.
//
// Serve doesn't return until the plugin is done being executed. Any
// errors will be outputted to os.Stderr.
//
// This is the method that plugins should call in their main() functions.
func Serve(opts *ServeConfig) {
	// Validate the handshake config
	if opts.MagicCookieKey == "" || opts.MagicCookieValue == "" {
		fmt.Fprintf(os.Stderr,
			"Misconfigured ServeConfig given to serve this plugin: no magic cookie\n"+
				"key or value was set. Please notify the plugin author and report\n"+
				"this as a bug.\n")
		os.Exit(1)
	}

	// First check the cookie: a mismatch means a human (or the wrong host)
	// executed this plugin binary directly.
	if os.Getenv(opts.MagicCookieKey) != opts.MagicCookieValue {
		fmt.Fprintf(os.Stderr,
			"This binary is a plugin. These are not meant to be executed directly.\n"+
				"Please execute the program that consumes these plugins, which will\n"+
				"load any plugins automatically\n")
		os.Exit(1)
	}

	// Logging goes to the original stderr
	log.SetOutput(os.Stderr)

	logger := opts.Logger
	if logger == nil {
		// internal logger to os.Stderr
		logger = hclog.New(&hclog.LoggerOptions{
			Level:      hclog.Trace,
			Output:     os.Stderr,
			JSONFormat: true,
		})
	}

	// Create our new stdout, stderr files. These will override our built-in
	// stdout/stderr so that it works across the stream boundary.
	stdout_r, stdout_w, err := os.Pipe()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error preparing plugin: %s\n", err)
		os.Exit(1)
	}
	stderr_r, stderr_w, err := os.Pipe()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error preparing plugin: %s\n", err)
		os.Exit(1)
	}

	// Register a listener so we can accept a connection
	listener, err := serverListener()
	if err != nil {
		logger.Error("plugin init error", "error", err)
		return
	}

	// Close the listener on return. We wrap this in a func() on purpose
	// because the "listener" reference may change to TLS.
	defer func() {
		listener.Close()
	}()

	var tlsConfig *tls.Config
	if opts.TLSProvider != nil {
		tlsConfig, err = opts.TLSProvider()
		if err != nil {
			logger.Error("plugin tls init", "error", err)
			return
		}
	}

	// Create the channel to tell us when we're done
	doneCh := make(chan struct{})

	// Build the server type
	var server ServerProtocol
	switch opts.Protocol() {
	case ProtocolNetRPC:
		// If we have a TLS configuration then we wrap the listener
		// ourselves and do it at that level.
		if tlsConfig != nil {
			listener = tls.NewListener(listener, tlsConfig)
		}

		// Create the RPC server to dispense
		server = &RPCServer{
			Plugins: opts.Plugins,
			Stdout:  stdout_r,
			Stderr:  stderr_r,
			DoneCh:  doneCh,
		}

	case ProtocolGRPC:
		// Create the gRPC server
		server = &GRPCServer{
			Plugins: opts.Plugins,
			Server:  opts.GRPCServer,
			TLS:     tlsConfig,
			Stdout:  stdout_r,
			Stderr:  stderr_r,
			DoneCh:  doneCh,
		}

	default:
		panic("unknown server protocol: " + opts.Protocol())
	}

	// Initialize the servers
	if err := server.Init(); err != nil {
		logger.Error("protocol init", "error", err)
		return
	}

	// Build the extra configuration, base64-encoded so the handshake line
	// stays parseable no matter what characters Config() produces.
	extra := ""
	if v := server.Config(); v != "" {
		extra = base64.StdEncoding.EncodeToString([]byte(v))
	}
	if extra != "" {
		extra = "|" + extra
	}

	logger.Debug("plugin address", "network", listener.Addr().Network(), "address", listener.Addr().String())

	// Output the address and service name to stdout so that core can bring it up.
	fmt.Printf("%d|%d|%s|%s|%s%s\n",
		CoreProtocolVersion,
		opts.ProtocolVersion,
		listener.Addr().Network(),
		listener.Addr().String(),
		opts.Protocol(),
		extra)
	os.Stdout.Sync()

	// Eat the interrupts: the host process manages the plugin's lifetime,
	// so terminal-wide signals (e.g. Ctrl-C) must not kill it directly.
	ch := make(chan os.Signal, 1)
	signal.Notify(ch, os.Interrupt)
	go func() {
		var count int32 = 0
		for {
			<-ch
			newCount := atomic.AddInt32(&count, 1)
			logger.Debug("plugin received interrupt signal, ignoring", "count", newCount)
		}
	}()

	// Set our new out, err. This must happen only after the handshake
	// line above was written to the real stdout.
	os.Stdout = stdout_w
	os.Stderr = stderr_w

	// Accept connections and wait for completion
	go server.Serve(listener)
	<-doneCh
}
+
+func serverListener() (net.Listener, error) {
+ if runtime.GOOS == "windows" {
+ return serverListener_tcp()
+ }
+
+ return serverListener_unix()
+}
+
+func serverListener_tcp() (net.Listener, error) {
+ minPort, err := strconv.ParseInt(os.Getenv("PLUGIN_MIN_PORT"), 10, 32)
+ if err != nil {
+ return nil, err
+ }
+
+ maxPort, err := strconv.ParseInt(os.Getenv("PLUGIN_MAX_PORT"), 10, 32)
+ if err != nil {
+ return nil, err
+ }
+
+ for port := minPort; port <= maxPort; port++ {
+ address := fmt.Sprintf("127.0.0.1:%d", port)
+ listener, err := net.Listen("tcp", address)
+ if err == nil {
+ return listener, nil
+ }
+ }
+
+ return nil, errors.New("Couldn't bind plugin TCP listener")
+}
+
+func serverListener_unix() (net.Listener, error) {
+ tf, err := ioutil.TempFile("", "plugin")
+ if err != nil {
+ return nil, err
+ }
+ path := tf.Name()
+
+ // Close the file and remove it because it has to not exist for
+ // the domain socket.
+ if err := tf.Close(); err != nil {
+ return nil, err
+ }
+ if err := os.Remove(path); err != nil {
+ return nil, err
+ }
+
+ l, err := net.Listen("unix", path)
+ if err != nil {
+ return nil, err
+ }
+
+ // Wrap the listener in rmListener so that the Unix domain socket file
+ // is removed on close.
+ return &rmListener{
+ Listener: l,
+ Path: path,
+ }, nil
+}
+
// rmListener is an implementation of net.Listener that forwards most
// calls to the listener but also removes a file as part of the close. We
// use this to cleanup the unix domain socket on close.
type rmListener struct {
	net.Listener
	// Path is the file removed when Close is called.
	Path string
}
+
+func (l *rmListener) Close() error {
+ // Close the listener itself
+ if err := l.Listener.Close(); err != nil {
+ return err
+ }
+
+ // Remove the file
+ return os.Remove(l.Path)
+}
--- /dev/null
+package plugin
+
+import (
+ "fmt"
+ "os"
+)
+
// ServeMuxMap is the type that is used to configure ServeMux: it maps the
// command-line argument (os.Args[1]) to the ServeConfig to serve.
type ServeMuxMap map[string]*ServeConfig
+
+// ServeMux is like Serve, but serves multiple types of plugins determined
+// by the argument given on the command-line.
+//
+// This command doesn't return until the plugin is done being executed. Any
+// errors are logged or output to stderr.
+func ServeMux(m ServeMuxMap) {
+ if len(os.Args) != 2 {
+ fmt.Fprintf(os.Stderr,
+ "Invoked improperly. This is an internal command that shouldn't\n"+
+ "be manually invoked.\n")
+ os.Exit(1)
+ }
+
+ opts, ok := m[os.Args[1]]
+ if !ok {
+ fmt.Fprintf(os.Stderr, "Unknown plugin: %s\n", os.Args[1])
+ os.Exit(1)
+ }
+
+ Serve(opts)
+}
--- /dev/null
+package plugin
+
+import (
+ "io/ioutil"
+ "net"
+ "os"
+ "testing"
+)
+
// TestRmListener_impl is a compile-time assertion that *rmListener
// satisfies net.Listener.
func TestRmListener_impl(t *testing.T) {
	var _ net.Listener = new(rmListener)
}
+
// TestRmListener verifies that closing an rmListener closes the wrapped
// listener and removes the associated file from disk.
func TestRmListener(t *testing.T) {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	// Create a real file for rmListener to delete on Close.
	tf, err := ioutil.TempFile("", "plugin")
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	path := tf.Name()

	// Close the file
	if err := tf.Close(); err != nil {
		t.Fatalf("err: %s", err)
	}

	// Create the listener and test close
	rmL := &rmListener{
		Listener: l,
		Path:     path,
	}
	if err := rmL.Close(); err != nil {
		t.Fatalf("err: %s", err)
	}

	// File should be gone
	// NOTE(review): err is nil when Stat unexpectedly succeeds, so this
	// Fatalf can print a nil error — consider a clearer message.
	if _, err := os.Stat(path); err == nil || !os.IsNotExist(err) {
		t.Fatalf("err: %s", err)
	}
}
--- /dev/null
+package plugin
+
+import (
+ "io"
+ "log"
+)
+
+func copyStream(name string, dst io.Writer, src io.Reader) {
+ if src == nil {
+ panic(name + ": src is nil")
+ }
+ if dst == nil {
+ panic(name + ": dst is nil")
+ }
+ if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
+ log.Printf("[ERR] plugin: stream copy '%s' error: %s", name, err)
+ }
+}
--- /dev/null
+package grpctest
+
+//go:generate protoc -I ./ ./test.proto --go_out=plugins=grpc:.
--- /dev/null
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: test.proto
+
+/*
+Package grpctest is a generated protocol buffer package.
+
+It is generated from these files:
+ test.proto
+
+It has these top-level messages:
+ TestRequest
+ TestResponse
+ PrintKVRequest
+ PrintKVResponse
+ BidirectionalRequest
+ BidirectionalResponse
+ PingRequest
+ PongResponse
+*/
+package grpctest
+
+import proto "github.com/golang/protobuf/proto"
+import fmt "fmt"
+import math "math"
+
+import (
+ context "golang.org/x/net/context"
+ grpc "google.golang.org/grpc"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
+
// TestRequest is the request message for Test.Double; Input carries the
// value the server doubles. (Generated protobuf code — regenerate from
// test.proto rather than editing by hand.)
type TestRequest struct {
	Input int32 `protobuf:"varint,1,opt,name=Input" json:"Input,omitempty"`
}

func (m *TestRequest) Reset()                    { *m = TestRequest{} }
func (m *TestRequest) String() string            { return proto.CompactTextString(m) }
func (*TestRequest) ProtoMessage()               {}
func (*TestRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }

// GetInput returns Input, tolerating a nil receiver as generated
// protobuf accessors do.
func (m *TestRequest) GetInput() int32 {
	if m != nil {
		return m.Input
	}
	return 0
}
+
// TestResponse is the reply message for Test.Double; Output is the doubled
// value. (Generated protobuf code.)
type TestResponse struct {
	Output int32 `protobuf:"varint,2,opt,name=Output" json:"Output,omitempty"`
}

func (m *TestResponse) Reset()                    { *m = TestResponse{} }
func (m *TestResponse) String() string            { return proto.CompactTextString(m) }
func (*TestResponse) ProtoMessage()               {}
func (*TestResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }

// GetOutput returns Output, tolerating a nil receiver.
func (m *TestResponse) GetOutput() int32 {
	if m != nil {
		return m.Output
	}
	return 0
}
+
// PrintKVRequest carries a key and a oneof value (string or int) for the
// Test.PrintKV RPC. (Generated protobuf code.)
type PrintKVRequest struct {
	Key string `protobuf:"bytes,1,opt,name=Key" json:"Key,omitempty"`
	// Types that are valid to be assigned to Value:
	//	*PrintKVRequest_ValueString
	//	*PrintKVRequest_ValueInt
	Value isPrintKVRequest_Value `protobuf_oneof:"Value"`
}

func (m *PrintKVRequest) Reset()                    { *m = PrintKVRequest{} }
func (m *PrintKVRequest) String() string            { return proto.CompactTextString(m) }
func (*PrintKVRequest) ProtoMessage()               {}
func (*PrintKVRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }

// isPrintKVRequest_Value is the interface implemented by the oneof wrapper
// types below.
type isPrintKVRequest_Value interface{ isPrintKVRequest_Value() }

type PrintKVRequest_ValueString struct {
	ValueString string `protobuf:"bytes,2,opt,name=ValueString,oneof"`
}
type PrintKVRequest_ValueInt struct {
	ValueInt int32 `protobuf:"varint,3,opt,name=ValueInt,oneof"`
}

func (*PrintKVRequest_ValueString) isPrintKVRequest_Value() {}
func (*PrintKVRequest_ValueInt) isPrintKVRequest_Value()    {}

func (m *PrintKVRequest) GetValue() isPrintKVRequest_Value {
	if m != nil {
		return m.Value
	}
	return nil
}

func (m *PrintKVRequest) GetKey() string {
	if m != nil {
		return m.Key
	}
	return ""
}

// GetValueString returns the oneof value if it holds a string, else "".
func (m *PrintKVRequest) GetValueString() string {
	if x, ok := m.GetValue().(*PrintKVRequest_ValueString); ok {
		return x.ValueString
	}
	return ""
}

// GetValueInt returns the oneof value if it holds an int, else 0.
func (m *PrintKVRequest) GetValueInt() int32 {
	if x, ok := m.GetValue().(*PrintKVRequest_ValueInt); ok {
		return x.ValueInt
	}
	return 0
}
+
// XXX_OneofFuncs is for the internal use of the proto package.
func (*PrintKVRequest) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) {
	return _PrintKVRequest_OneofMarshaler, _PrintKVRequest_OneofUnmarshaler, _PrintKVRequest_OneofSizer, []interface{}{
		(*PrintKVRequest_ValueString)(nil),
		(*PrintKVRequest_ValueInt)(nil),
	}
}

// _PrintKVRequest_OneofMarshaler encodes whichever oneof variant is set.
func _PrintKVRequest_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
	m := msg.(*PrintKVRequest)
	// Value
	switch x := m.Value.(type) {
	case *PrintKVRequest_ValueString:
		b.EncodeVarint(2<<3 | proto.WireBytes)
		b.EncodeStringBytes(x.ValueString)
	case *PrintKVRequest_ValueInt:
		b.EncodeVarint(3<<3 | proto.WireVarint)
		b.EncodeVarint(uint64(x.ValueInt))
	case nil:
	default:
		return fmt.Errorf("PrintKVRequest.Value has unexpected type %T", x)
	}
	return nil
}

// _PrintKVRequest_OneofUnmarshaler decodes a oneof field by tag number.
func _PrintKVRequest_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
	m := msg.(*PrintKVRequest)
	switch tag {
	case 2: // Value.ValueString
		if wire != proto.WireBytes {
			return true, proto.ErrInternalBadWireType
		}
		x, err := b.DecodeStringBytes()
		m.Value = &PrintKVRequest_ValueString{x}
		return true, err
	case 3: // Value.ValueInt
		if wire != proto.WireVarint {
			return true, proto.ErrInternalBadWireType
		}
		x, err := b.DecodeVarint()
		m.Value = &PrintKVRequest_ValueInt{int32(x)}
		return true, err
	default:
		return false, nil
	}
}

// _PrintKVRequest_OneofSizer computes the encoded size of the set variant.
func _PrintKVRequest_OneofSizer(msg proto.Message) (n int) {
	m := msg.(*PrintKVRequest)
	// Value
	switch x := m.Value.(type) {
	case *PrintKVRequest_ValueString:
		n += proto.SizeVarint(2<<3 | proto.WireBytes)
		n += proto.SizeVarint(uint64(len(x.ValueString)))
		n += len(x.ValueString)
	case *PrintKVRequest_ValueInt:
		n += proto.SizeVarint(3<<3 | proto.WireVarint)
		n += proto.SizeVarint(uint64(x.ValueInt))
	case nil:
	default:
		panic(fmt.Sprintf("proto: unexpected type %T in oneof", x))
	}
	return n
}
+
// PrintKVResponse is the (empty) reply for Test.PrintKV. (Generated
// protobuf code.)
type PrintKVResponse struct {
}

func (m *PrintKVResponse) Reset()                    { *m = PrintKVResponse{} }
func (m *PrintKVResponse) String() string            { return proto.CompactTextString(m) }
func (*PrintKVResponse) ProtoMessage()               {}
func (*PrintKVResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }

// BidirectionalRequest carries a broker stream id for the bidirectional
// test RPC.
type BidirectionalRequest struct {
	Id uint32 `protobuf:"varint,1,opt,name=id" json:"id,omitempty"`
}

func (m *BidirectionalRequest) Reset()                    { *m = BidirectionalRequest{} }
func (m *BidirectionalRequest) String() string            { return proto.CompactTextString(m) }
func (*BidirectionalRequest) ProtoMessage()               {}
func (*BidirectionalRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }

func (m *BidirectionalRequest) GetId() uint32 {
	if m != nil {
		return m.Id
	}
	return 0
}

// BidirectionalResponse echoes a broker stream id back to the caller.
type BidirectionalResponse struct {
	Id uint32 `protobuf:"varint,1,opt,name=id" json:"id,omitempty"`
}

func (m *BidirectionalResponse) Reset()                    { *m = BidirectionalResponse{} }
func (m *BidirectionalResponse) String() string            { return proto.CompactTextString(m) }
func (*BidirectionalResponse) ProtoMessage()               {}
func (*BidirectionalResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }

func (m *BidirectionalResponse) GetId() uint32 {
	if m != nil {
		return m.Id
	}
	return 0
}

// PingRequest is the (empty) request for the ping test RPC.
type PingRequest struct {
}

func (m *PingRequest) Reset()                    { *m = PingRequest{} }
func (m *PingRequest) String() string            { return proto.CompactTextString(m) }
func (*PingRequest) ProtoMessage()               {}
func (*PingRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }

// PongResponse carries the ping reply text.
type PongResponse struct {
	Msg string `protobuf:"bytes,1,opt,name=msg" json:"msg,omitempty"`
}

func (m *PongResponse) Reset()                    { *m = PongResponse{} }
func (m *PongResponse) String() string            { return proto.CompactTextString(m) }
func (*PongResponse) ProtoMessage()               {}
func (*PongResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }

func (m *PongResponse) GetMsg() string {
	if m != nil {
		return m.Msg
	}
	return ""
}
+
// init registers every generated message type with the proto runtime under
// its fully-qualified "grpctest.*" name.
func init() {
	proto.RegisterType((*TestRequest)(nil), "grpctest.TestRequest")
	proto.RegisterType((*TestResponse)(nil), "grpctest.TestResponse")
	proto.RegisterType((*PrintKVRequest)(nil), "grpctest.PrintKVRequest")
	proto.RegisterType((*PrintKVResponse)(nil), "grpctest.PrintKVResponse")
	proto.RegisterType((*BidirectionalRequest)(nil), "grpctest.BidirectionalRequest")
	proto.RegisterType((*BidirectionalResponse)(nil), "grpctest.BidirectionalResponse")
	proto.RegisterType((*PingRequest)(nil), "grpctest.PingRequest")
	proto.RegisterType((*PongResponse)(nil), "grpctest.PongResponse")
}
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ context.Context
+var _ grpc.ClientConn
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the grpc package it is being compiled against.
+const _ = grpc.SupportPackageIsVersion4
+
+// Client API for Test service
+
// TestClient is the client API for the Test service. Generated code; for
// semantics around ctx use and closing/ending streaming RPCs, refer to the
// gRPC documentation.
type TestClient interface {
	Double(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error)
	PrintKV(ctx context.Context, in *PrintKVRequest, opts ...grpc.CallOption) (*PrintKVResponse, error)
	Bidirectional(ctx context.Context, in *BidirectionalRequest, opts ...grpc.CallOption) (*BidirectionalResponse, error)
	Stream(ctx context.Context, opts ...grpc.CallOption) (Test_StreamClient, error)
}

// testClient implements TestClient over a shared *grpc.ClientConn.
type testClient struct {
	cc *grpc.ClientConn
}

// NewTestClient wraps an existing client connection in a TestClient.
func NewTestClient(cc *grpc.ClientConn) TestClient {
	return &testClient{cc}
}

// Double invokes the unary /grpctest.Test/Double RPC.
func (c *testClient) Double(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error) {
	out := new(TestResponse)
	err := grpc.Invoke(ctx, "/grpctest.Test/Double", in, out, c.cc, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// PrintKV invokes the unary /grpctest.Test/PrintKV RPC.
func (c *testClient) PrintKV(ctx context.Context, in *PrintKVRequest, opts ...grpc.CallOption) (*PrintKVResponse, error) {
	out := new(PrintKVResponse)
	err := grpc.Invoke(ctx, "/grpctest.Test/PrintKV", in, out, c.cc, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// Bidirectional invokes the unary /grpctest.Test/Bidirectional RPC.
func (c *testClient) Bidirectional(ctx context.Context, in *BidirectionalRequest, opts ...grpc.CallOption) (*BidirectionalResponse, error) {
	out := new(BidirectionalResponse)
	err := grpc.Invoke(ctx, "/grpctest.Test/Bidirectional", in, out, c.cc, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// Stream opens the bidirectional-streaming /grpctest.Test/Stream RPC.
func (c *testClient) Stream(ctx context.Context, opts ...grpc.CallOption) (Test_StreamClient, error) {
	stream, err := grpc.NewClientStream(ctx, &_Test_serviceDesc.Streams[0], c.cc, "/grpctest.Test/Stream", opts...)
	if err != nil {
		return nil, err
	}
	x := &testStreamClient{stream}
	return x, nil
}

// Test_StreamClient is the client-side handle for the Stream RPC.
type Test_StreamClient interface {
	Send(*TestRequest) error
	Recv() (*TestResponse, error)
	grpc.ClientStream
}

type testStreamClient struct {
	grpc.ClientStream
}

// Send writes one request message onto the stream.
func (x *testStreamClient) Send(m *TestRequest) error {
	return x.ClientStream.SendMsg(m)
}

// Recv reads the next response message from the stream.
func (x *testStreamClient) Recv() (*TestResponse, error) {
	m := new(TestResponse)
	if err := x.ClientStream.RecvMsg(m); err != nil {
		return nil, err
	}
	return m, nil
}
+
+// Server API for Test service
+
// TestServer is the server API for the Test service. Generated code.
type TestServer interface {
	Double(context.Context, *TestRequest) (*TestResponse, error)
	PrintKV(context.Context, *PrintKVRequest) (*PrintKVResponse, error)
	Bidirectional(context.Context, *BidirectionalRequest) (*BidirectionalResponse, error)
	Stream(Test_StreamServer) error
}

// RegisterTestServer registers the Test service implementation srv with s.
func RegisterTestServer(s *grpc.Server, srv TestServer) {
	s.RegisterService(&_Test_serviceDesc, srv)
}

// _Test_Double_Handler decodes the request, then dispatches to srv.Double,
// routing through interceptor when one is installed.
func _Test_Double_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(TestRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(TestServer).Double(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server: srv,
		FullMethod: "/grpctest.Test/Double",
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(TestServer).Double(ctx, req.(*TestRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _Test_PrintKV_Handler decodes the request, then dispatches to srv.PrintKV,
// routing through interceptor when one is installed.
func _Test_PrintKV_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(PrintKVRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(TestServer).PrintKV(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server: srv,
		FullMethod: "/grpctest.Test/PrintKV",
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(TestServer).PrintKV(ctx, req.(*PrintKVRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _Test_Bidirectional_Handler decodes the request, then dispatches to
// srv.Bidirectional, routing through interceptor when one is installed.
func _Test_Bidirectional_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(BidirectionalRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(TestServer).Bidirectional(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server: srv,
		FullMethod: "/grpctest.Test/Bidirectional",
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(TestServer).Bidirectional(ctx, req.(*BidirectionalRequest))
	}
	return interceptor(ctx, in, info, handler)
}

// _Test_Stream_Handler adapts the raw grpc.ServerStream into the typed
// Test_StreamServer expected by the service implementation.
func _Test_Stream_Handler(srv interface{}, stream grpc.ServerStream) error {
	return srv.(TestServer).Stream(&testStreamServer{stream})
}

// Test_StreamServer is the server-side handle for the Stream RPC.
type Test_StreamServer interface {
	Send(*TestResponse) error
	Recv() (*TestRequest, error)
	grpc.ServerStream
}

type testStreamServer struct {
	grpc.ServerStream
}

// Send writes one response message onto the stream.
func (x *testStreamServer) Send(m *TestResponse) error {
	return x.ServerStream.SendMsg(m)
}

// Recv reads the next request message from the stream.
func (x *testStreamServer) Recv() (*TestRequest, error) {
	m := new(TestRequest)
	if err := x.ServerStream.RecvMsg(m); err != nil {
		return nil, err
	}
	return m, nil
}
+
// _Test_serviceDesc wires the Test service's method and stream names to their
// handler functions for registration with a grpc.Server.
var _Test_serviceDesc = grpc.ServiceDesc{
	ServiceName: "grpctest.Test",
	HandlerType: (*TestServer)(nil),
	Methods: []grpc.MethodDesc{
		{
			MethodName: "Double",
			Handler: _Test_Double_Handler,
		},
		{
			MethodName: "PrintKV",
			Handler: _Test_PrintKV_Handler,
		},
		{
			MethodName: "Bidirectional",
			Handler: _Test_Bidirectional_Handler,
		},
	},
	Streams: []grpc.StreamDesc{
		{
			StreamName: "Stream",
			Handler: _Test_Stream_Handler,
			// Bidirectional streaming: both sides may send.
			ServerStreams: true,
			ClientStreams: true,
		},
	},
	Metadata: "test.proto",
}
+
+// Client API for PingPong service
+
// PingPongClient is the client API for the PingPong service. Generated code.
type PingPongClient interface {
	Ping(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (*PongResponse, error)
}

// pingPongClient implements PingPongClient over a shared *grpc.ClientConn.
type pingPongClient struct {
	cc *grpc.ClientConn
}

// NewPingPongClient wraps an existing client connection in a PingPongClient.
func NewPingPongClient(cc *grpc.ClientConn) PingPongClient {
	return &pingPongClient{cc}
}

// Ping invokes the unary /grpctest.PingPong/Ping RPC.
func (c *pingPongClient) Ping(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (*PongResponse, error) {
	out := new(PongResponse)
	err := grpc.Invoke(ctx, "/grpctest.PingPong/Ping", in, out, c.cc, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}
+
+// Server API for PingPong service
+
// PingPongServer is the server API for the PingPong service. Generated code.
type PingPongServer interface {
	Ping(context.Context, *PingRequest) (*PongResponse, error)
}

// RegisterPingPongServer registers the PingPong service implementation srv
// with s.
func RegisterPingPongServer(s *grpc.Server, srv PingPongServer) {
	s.RegisterService(&_PingPong_serviceDesc, srv)
}

// _PingPong_Ping_Handler decodes the request, then dispatches to srv.Ping,
// routing through interceptor when one is installed.
func _PingPong_Ping_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
	in := new(PingRequest)
	if err := dec(in); err != nil {
		return nil, err
	}
	if interceptor == nil {
		return srv.(PingPongServer).Ping(ctx, in)
	}
	info := &grpc.UnaryServerInfo{
		Server: srv,
		FullMethod: "/grpctest.PingPong/Ping",
	}
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return srv.(PingPongServer).Ping(ctx, req.(*PingRequest))
	}
	return interceptor(ctx, in, info, handler)
}
+
// _PingPong_serviceDesc wires the PingPong service's single unary method to
// its handler for registration with a grpc.Server.
var _PingPong_serviceDesc = grpc.ServiceDesc{
	ServiceName: "grpctest.PingPong",
	HandlerType: (*PingPongServer)(nil),
	Methods: []grpc.MethodDesc{
		{
			MethodName: "Ping",
			Handler: _PingPong_Ping_Handler,
		},
	},
	Streams: []grpc.StreamDesc{},
	Metadata: "test.proto",
}
+
// init registers the compressed descriptor for test.proto with the proto
// runtime.
func init() { proto.RegisterFile("test.proto", fileDescriptor0) }

// fileDescriptor0 is the gzipped FileDescriptorProto for test.proto.
// Generated data — do not edit by hand.
var fileDescriptor0 = []byte{
	// 355 bytes of a gzipped FileDescriptorProto
	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x92, 0xcd, 0x4e, 0xc2, 0x40,
	0x14, 0x85, 0xdb, 0x02, 0x05, 0x2e, 0x3f, 0xe2, 0x04, 0x08, 0x12, 0xa3, 0x64, 0x4c, 0x90, 0x15,
	0x31, 0xb8, 0x30, 0x2e, 0x4c, 0x0c, 0xba, 0x80, 0xb0, 0x90, 0x0c, 0x86, 0x3d, 0x3f, 0x93, 0x66,
	0x12, 0x3a, 0xad, 0x9d, 0xe9, 0xc2, 0x17, 0xf1, 0x79, 0xcd, 0x0c, 0x6d, 0x19, 0x08, 0x2e, 0xdc,
	0xdd, 0x73, 0x7b, 0x72, 0xe6, 0x9e, 0x2f, 0x05, 0x90, 0x54, 0xc8, 0x61, 0x18, 0x05, 0x32, 0x40,
	0x25, 0x2f, 0x0a, 0x37, 0x4a, 0xe3, 0x3b, 0xa8, 0x7c, 0x52, 0x21, 0x09, 0xfd, 0x8a, 0xa9, 0x90,
	0xa8, 0x09, 0x85, 0x29, 0x0f, 0x63, 0xd9, 0xb1, 0x7b, 0xf6, 0xa0, 0x40, 0xf6, 0x02, 0xf7, 0xa1,
	0xba, 0x37, 0x89, 0x30, 0xe0, 0x82, 0xa2, 0x36, 0xb8, 0x1f, 0xb1, 0x54, 0x36, 0x47, 0xdb, 0x12,
	0x85, 0x7d, 0xa8, 0xcf, 0x23, 0xc6, 0xe5, 0x6c, 0x99, 0xe6, 0x35, 0x20, 0x37, 0xa3, 0xdf, 0x3a,
	0xad, 0x4c, 0xd4, 0x88, 0x30, 0x54, 0x96, 0xab, 0x5d, 0x4c, 0x17, 0x32, 0x62, 0xdc, 0xd3, 0x01,
	0xe5, 0x89, 0x45, 0xcc, 0x25, 0xba, 0x86, 0x92, 0x96, 0x53, 0x2e, 0x3b, 0x39, 0xf5, 0xc2, 0xc4,
	0x22, 0xd9, 0x66, 0x5c, 0x84, 0x82, 0x9e, 0xf1, 0x25, 0x5c, 0x64, 0xcf, 0xed, 0x2f, 0xc3, 0x7d,
	0x68, 0x8e, 0xd9, 0x96, 0x45, 0x74, 0x23, 0x59, 0xc0, 0x57, 0xbb, 0xf4, 0x8e, 0x3a, 0x38, 0x6c,
	0xab, 0xcf, 0xa8, 0x11, 0x87, 0x6d, 0xf1, 0x3d, 0xb4, 0x4e, 0x7c, 0x49, 0xb5, 0x53, 0x63, 0x0d,
	0x2a, 0x73, 0xc6, 0xbd, 0x24, 0x07, 0xf7, 0xa0, 0x3a, 0x0f, 0x94, 0x4c, 0xec, 0x0d, 0xc8, 0xf9,
	0xc2, 0x4b, 0xfb, 0xf9, 0xc2, 0x1b, 0xfd, 0x38, 0x90, 0x57, 0xb0, 0xd0, 0x33, 0xb8, 0xef, 0x41,
	0xbc, 0xde, 0x51, 0xd4, 0x1a, 0xa6, 0xb8, 0x87, 0x06, 0xeb, 0x6e, 0xfb, 0x74, 0x9d, 0x74, 0xb0,
	0xd0, 0x2b, 0x14, 0x93, 0x62, 0xa8, 0x73, 0x30, 0x1d, 0xa3, 0xed, 0x5e, 0x9d, 0xf9, 0x92, 0x25,
	0x10, 0xa8, 0x1d, 0xf5, 0x43, 0x37, 0x07, 0xf7, 0x39, 0x40, 0xdd, 0xdb, 0x3f, 0xbf, 0x67, 0x99,
	0x2f, 0xe0, 0x2e, 0x64, 0x44, 0x57, 0xfe, 0xbf, 0x0b, 0x0d, 0xec, 0x07, 0x7b, 0xf4, 0x06, 0x25,
	0x45, 0x52, 0xe1, 0x43, 0x4f, 0x90, 0x57, 0xb3, 0x19, 0x64, 0x50, 0x36, 0x83, 0x4c, 0xda, 0xd8,
	0x5a, 0xbb, 0xfa, 0xff, 0x7d, 0xfc, 0x0d, 0x00, 0x00, 0xff, 0xff, 0xf0, 0x59, 0x20, 0xc7, 0xcd,
	0x02, 0x00, 0x00,
}
--- /dev/null
syntax = "proto3";
package grpctest;

// TestRequest is the input to the unary Double and streaming Stream RPCs.
message TestRequest {
  int32 Input = 1;
}

// TestResponse carries the result of Double/Stream.
message TestResponse {
  int32 Output = 2;
}

// PrintKVRequest carries one key plus a oneof value (string or int).
message PrintKVRequest {
  string Key = 1;
  oneof Value {
    string ValueString = 2;
    int32 ValueInt = 3;
  }
}

// PrintKVResponse is intentionally empty.
message PrintKVResponse {

}

// BidirectionalRequest carries an id echoed back by the Bidirectional RPC.
message BidirectionalRequest {
  uint32 id = 1;
}

// BidirectionalResponse carries the echoed id.
message BidirectionalResponse {
  uint32 id = 1;
}

// Test exercises unary, oneof, and bidirectional-streaming RPC shapes.
service Test {
  rpc Double(TestRequest) returns (TestResponse) {}
  rpc PrintKV(PrintKVRequest) returns (PrintKVResponse) {}
  rpc Bidirectional(BidirectionalRequest) returns (BidirectionalResponse) {}
  rpc Stream(stream TestRequest) returns (stream TestResponse) {}
}

// PingRequest is intentionally empty.
message PingRequest {
}

// PongResponse carries the reply string of Ping.
message PongResponse {
  string msg = 1;
}

// PingPong is a minimal single-method service.
service PingPong {
  rpc Ping(PingRequest) returns (PongResponse) {}
}
--- /dev/null
+package plugin
+
+import (
+ "bytes"
+ "context"
+ "io"
+ "net"
+ "net/rpc"
+
+ "github.com/mitchellh/go-testing-interface"
+ "google.golang.org/grpc"
+)
+
// TestOptions allows specifying options that can affect the behavior of the
// test functions.
type TestOptions struct {
	// ServerStdout causes the given value to be used in place of a blank
	// buffer for RPCServer's Stdout.
	ServerStdout io.ReadCloser

	// ServerStderr causes the given value to be used in place of a blank
	// buffer for RPCServer's Stderr.
	ServerStderr io.ReadCloser
}
+
+// The testing file contains test helpers that you can use outside of
+// this package for making it easier to test plugins themselves.
+
+// TestConn is a helper function for returning a client and server
+// net.Conn connected to each other.
+func TestConn(t testing.T) (net.Conn, net.Conn) {
+ // Listen to any local port. This listener will be closed
+ // after a single connection is established.
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("err: %s", err)
+ }
+
+ // Start a goroutine to accept our client connection
+ var serverConn net.Conn
+ doneCh := make(chan struct{})
+ go func() {
+ defer close(doneCh)
+ defer l.Close()
+ var err error
+ serverConn, err = l.Accept()
+ if err != nil {
+ t.Fatalf("err: %s", err)
+ }
+ }()
+
+ // Connect to the server
+ clientConn, err := net.Dial("tcp", l.Addr().String())
+ if err != nil {
+ t.Fatalf("err: %s", err)
+ }
+
+ // Wait for the server side to acknowledge it has connected
+ <-doneCh
+
+ return clientConn, serverConn
+}
+
+// TestRPCConn returns a rpc client and server connected to each other.
+func TestRPCConn(t testing.T) (*rpc.Client, *rpc.Server) {
+ clientConn, serverConn := TestConn(t)
+
+ server := rpc.NewServer()
+ go server.ServeConn(serverConn)
+
+ client := rpc.NewClient(clientConn)
+ return client, server
+}
+
// TestPluginRPCConn returns a plugin RPC client and server that are connected
// together and configured. The server's Stdout/Stderr default to fresh
// in-memory buffers unless overridden via opts.
func TestPluginRPCConn(t testing.T, ps map[string]Plugin, opts *TestOptions) (*RPCClient, *RPCServer) {
	// Create two net.Conns we can use to shuttle our control connection
	clientConn, serverConn := TestConn(t)

	// Start up the server
	server := &RPCServer{Plugins: ps, Stdout: new(bytes.Buffer), Stderr: new(bytes.Buffer)}
	if opts != nil {
		// Optional stream overrides replace the default blank buffers.
		if opts.ServerStdout != nil {
			server.Stdout = opts.ServerStdout
		}
		if opts.ServerStderr != nil {
			server.Stderr = opts.ServerStderr
		}
	}
	go server.ServeConn(serverConn)

	// Connect the client to the server
	client, err := NewRPCClient(clientConn, ps)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	return client, server
}
+
// TestGRPCConn returns a gRPC client conn and grpc server that are connected
// together and configured. The register function is used to register services
// prior to the Serve call. This is used to test gRPC connections.
func TestGRPCConn(t testing.T, register func(*grpc.Server)) (*grpc.ClientConn, *grpc.Server) {
	// Create a listener
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	// Services must be registered before Serve is called.
	server := grpc.NewServer()
	register(server)
	go server.Serve(l)

	// Connect to the server. WithBlock makes Dial synchronous, so a
	// returned conn is already established.
	conn, err := grpc.Dial(
		l.Addr().String(),
		grpc.WithBlock(),
		grpc.WithInsecure())
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	// Connection successful, close the listener so no further
	// connections are accepted.
	l.Close()

	return conn, server
}
+
// TestPluginGRPCConn returns a plugin gRPC client and server that are connected
// together and configured. This is used to test gRPC connections.
func TestPluginGRPCConn(t testing.T, ps map[string]Plugin) (*GRPCClient, *GRPCServer) {
	// Create a listener
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	// Start up the server with blank in-memory stdio buffers.
	server := &GRPCServer{
		Plugins: ps,
		Server: DefaultGRPCServer,
		Stdout: new(bytes.Buffer),
		Stderr: new(bytes.Buffer),
	}
	if err := server.Init(); err != nil {
		t.Fatalf("err: %s", err)
	}
	go server.Serve(l)

	// Connect to the server (WithBlock: Dial returns only once connected)
	conn, err := grpc.Dial(
		l.Addr().String(),
		grpc.WithBlock(),
		grpc.WithInsecure())
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	// Start the broker plumbing used for brokered sub-connections.
	brokerGRPCClient := newGRPCBrokerClient(conn)
	broker := newGRPCBroker(brokerGRPCClient, nil)
	go broker.Run()
	go brokerGRPCClient.StartStream()

	// Create the client. doneCtx is Background, so the client is never
	// canceled by a parent context in this test helper.
	client := &GRPCClient{
		Conn: conn,
		Plugins: ps,
		broker: broker,
		doneCtx: context.Background(),
	}

	return client, server
}
--- /dev/null
+# Compiled Object files, Static and Dynamic libs (Shared Objects)
+*.o
+*.a
+*.so
+
+# Folders
+_obj
+_test
+
+# Architecture specific extensions/prefixes
+*.[568vq]
+[568vq].out
+
+*.cgo1.go
+*.cgo2.c
+_cgo_defun.c
+_cgo_gotypes.go
+_cgo_export.*
+
+_testmain.go
+
+*.exe
+*.test
--- /dev/null
+package lru
+
+import (
+ "fmt"
+ "sync"
+
+ "github.com/hashicorp/golang-lru/simplelru"
+)
+
const (
	// Default2QRecentRatio is the fraction of the 2Q cache dedicated
	// to recently added entries that have only been accessed once.
	Default2QRecentRatio = 0.25

	// Default2QGhostEntries is the default ratio of ghost
	// entries kept to track entries recently evicted.
	Default2QGhostEntries = 0.50
)
+
// TwoQueueCache is a thread-safe fixed size 2Q cache.
// 2Q is an enhancement over the standard LRU cache
// in that it tracks both frequently and recently used
// entries separately. This avoids a burst in access to new
// entries from evicting frequently used entries. It adds some
// additional tracking overhead to the standard LRU cache, and is
// computationally about 2x the cost, and adds some metadata over
// head. The ARCCache is similar, but does not require setting any
// parameters.
type TwoQueueCache struct {
	// size is the total capacity; recentSize is the target share of
	// that capacity reserved for the recent list.
	size int
	recentSize int

	// recent holds entries seen once; frequent holds entries seen
	// again; recentEvict is the "ghost" list of recently evicted keys
	// (keys only, nil values).
	recent simplelru.LRUCache
	frequent simplelru.LRUCache
	recentEvict simplelru.LRUCache
	lock sync.RWMutex
}
+
// New2Q creates a new TwoQueueCache using the default
// values for the parameters (25% recent share, 50% ghost entries).
func New2Q(size int) (*TwoQueueCache, error) {
	return New2QParams(size, Default2QRecentRatio, Default2QGhostEntries)
}
+
+// New2QParams creates a new TwoQueueCache using the provided
+// parameter values.
+func New2QParams(size int, recentRatio float64, ghostRatio float64) (*TwoQueueCache, error) {
+ if size <= 0 {
+ return nil, fmt.Errorf("invalid size")
+ }
+ if recentRatio < 0.0 || recentRatio > 1.0 {
+ return nil, fmt.Errorf("invalid recent ratio")
+ }
+ if ghostRatio < 0.0 || ghostRatio > 1.0 {
+ return nil, fmt.Errorf("invalid ghost ratio")
+ }
+
+ // Determine the sub-sizes
+ recentSize := int(float64(size) * recentRatio)
+ evictSize := int(float64(size) * ghostRatio)
+
+ // Allocate the LRUs
+ recent, err := simplelru.NewLRU(size, nil)
+ if err != nil {
+ return nil, err
+ }
+ frequent, err := simplelru.NewLRU(size, nil)
+ if err != nil {
+ return nil, err
+ }
+ recentEvict, err := simplelru.NewLRU(evictSize, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ // Initialize the cache
+ c := &TwoQueueCache{
+ size: size,
+ recentSize: recentSize,
+ recent: recent,
+ frequent: frequent,
+ recentEvict: recentEvict,
+ }
+ return c, nil
+}
+
+// Get looks up a key's value from the cache.
+func (c *TwoQueueCache) Get(key interface{}) (value interface{}, ok bool) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ // Check if this is a frequent value
+ if val, ok := c.frequent.Get(key); ok {
+ return val, ok
+ }
+
+ // If the value is contained in recent, then we
+ // promote it to frequent
+ if val, ok := c.recent.Peek(key); ok {
+ c.recent.Remove(key)
+ c.frequent.Add(key, val)
+ return val, ok
+ }
+
+ // No hit
+ return nil, false
+}
+
+// Add adds a value to the cache.
+func (c *TwoQueueCache) Add(key, value interface{}) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ // Check if the value is frequently used already,
+ // and just update the value
+ if c.frequent.Contains(key) {
+ c.frequent.Add(key, value)
+ return
+ }
+
+ // Check if the value is recently used, and promote
+ // the value into the frequent list
+ if c.recent.Contains(key) {
+ c.recent.Remove(key)
+ c.frequent.Add(key, value)
+ return
+ }
+
+ // If the value was recently evicted, add it to the
+ // frequently used list
+ if c.recentEvict.Contains(key) {
+ c.ensureSpace(true)
+ c.recentEvict.Remove(key)
+ c.frequent.Add(key, value)
+ return
+ }
+
+ // Add to the recently seen list
+ c.ensureSpace(false)
+ c.recent.Add(key, value)
+ return
+}
+
+// ensureSpace is used to ensure we have space in the cache
+func (c *TwoQueueCache) ensureSpace(recentEvict bool) {
+ // If we have space, nothing to do
+ recentLen := c.recent.Len()
+ freqLen := c.frequent.Len()
+ if recentLen+freqLen < c.size {
+ return
+ }
+
+ // If the recent buffer is larger than
+ // the target, evict from there
+ if recentLen > 0 && (recentLen > c.recentSize || (recentLen == c.recentSize && !recentEvict)) {
+ k, _, _ := c.recent.RemoveOldest()
+ c.recentEvict.Add(k, nil)
+ return
+ }
+
+ // Remove from the frequent list otherwise
+ c.frequent.RemoveOldest()
+}
+
+// Len returns the number of items in the cache.
+func (c *TwoQueueCache) Len() int {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+ return c.recent.Len() + c.frequent.Len()
+}
+
+// Keys returns a slice of the keys in the cache.
+// The frequently used keys are first in the returned slice.
+func (c *TwoQueueCache) Keys() []interface{} {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+ k1 := c.frequent.Keys()
+ k2 := c.recent.Keys()
+ return append(k1, k2...)
+}
+
+// Remove removes the provided key from the cache.
+func (c *TwoQueueCache) Remove(key interface{}) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ if c.frequent.Remove(key) {
+ return
+ }
+ if c.recent.Remove(key) {
+ return
+ }
+ if c.recentEvict.Remove(key) {
+ return
+ }
+}
+
+// Purge is used to completely clear the cache.
+func (c *TwoQueueCache) Purge() {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ c.recent.Purge()
+ c.frequent.Purge()
+ c.recentEvict.Purge()
+}
+
+// Contains is used to check if the cache contains a key
+// without updating recency or frequency.
+func (c *TwoQueueCache) Contains(key interface{}) bool {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+ return c.frequent.Contains(key) || c.recent.Contains(key)
+}
+
+// Peek is used to inspect the cache value of a key
+// without updating recency or frequency.
+func (c *TwoQueueCache) Peek(key interface{}) (value interface{}, ok bool) {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+ if val, ok := c.frequent.Peek(key); ok {
+ return val, ok
+ }
+ return c.recent.Peek(key)
+}
--- /dev/null
+package lru
+
+import (
+ "math/rand"
+ "testing"
+)
+
// Benchmark2Q_Rand measures 2Q throughput under a uniformly random trace of
// interleaved Add and Get operations.
func Benchmark2Q_Rand(b *testing.B) {
	l, err := New2Q(8192)
	if err != nil {
		b.Fatalf("err: %v", err)
	}

	// Uniform keys drawn from four times the cache capacity.
	trace := make([]int64, b.N*2)
	for i := 0; i < b.N*2; i++ {
		trace[i] = rand.Int63() % 32768
	}

	b.ResetTimer()

	// Alternate Add (even i) and Get (odd i) over the trace.
	var hit, miss int
	for i := 0; i < 2*b.N; i++ {
		if i%2 == 0 {
			l.Add(trace[i], trace[i])
		} else {
			_, ok := l.Get(trace[i])
			if ok {
				hit++
			} else {
				miss++
			}
		}
	}
	// NOTE(review): "ratio" here is hit/miss, not a hit rate, and prints
	// +Inf when miss is zero.
	b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(miss))
}
+
// Benchmark2Q_Freq measures 2Q behavior on a skewed trace: half the keys come
// from a hot set (the cache capacity's range) and half from a range twice as
// large, then the whole trace is added and re-read.
func Benchmark2Q_Freq(b *testing.B) {
	l, err := New2Q(8192)
	if err != nil {
		b.Fatalf("err: %v", err)
	}

	// Even slots draw from a hot key range, odd slots from a cold one.
	trace := make([]int64, b.N*2)
	for i := 0; i < b.N*2; i++ {
		if i%2 == 0 {
			trace[i] = rand.Int63() % 16384
		} else {
			trace[i] = rand.Int63() % 32768
		}
	}

	b.ResetTimer()

	// Populate with the first half of the trace, then measure hits on
	// re-reading those same keys.
	for i := 0; i < b.N; i++ {
		l.Add(trace[i], trace[i])
	}
	var hit, miss int
	for i := 0; i < b.N; i++ {
		_, ok := l.Get(trace[i])
		if ok {
			hit++
		} else {
			miss++
		}
	}
	// NOTE(review): same caveat as Benchmark2Q_Rand — hit/miss, not a
	// hit rate; +Inf when miss is zero.
	b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(miss))
}
+
+func Test2Q_RandomOps(t *testing.T) {
+ size := 128
+ l, err := New2Q(128)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ n := 200000
+ for i := 0; i < n; i++ {
+ key := rand.Int63() % 512
+ r := rand.Int63()
+ switch r % 3 {
+ case 0:
+ l.Add(key, key)
+ case 1:
+ l.Get(key)
+ case 2:
+ l.Remove(key)
+ }
+
+ if l.recent.Len()+l.frequent.Len() > size {
+ t.Fatalf("bad: recent: %d freq: %d",
+ l.recent.Len(), l.frequent.Len())
+ }
+ }
+}
+
// Test2Q_Get_RecentToFrequent verifies that a Get promotes an entry from the
// recent list (t1) to the frequent list (t2), and that it then stays there.
func Test2Q_Get_RecentToFrequent(t *testing.T) {
	l, err := New2Q(128)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Touch all the entries, should be in t1 (the recent list)
	for i := 0; i < 128; i++ {
		l.Add(i, i)
	}
	if n := l.recent.Len(); n != 128 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.frequent.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}

	// Get should upgrade to t2 (the frequent list)
	for i := 0; i < 128; i++ {
		_, ok := l.Get(i)
		if !ok {
			t.Fatalf("missing: %d", i)
		}
	}
	if n := l.recent.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.frequent.Len(); n != 128 {
		t.Fatalf("bad: %d", n)
	}

	// Subsequent Gets should be served from t2 without moving entries
	for i := 0; i < 128; i++ {
		_, ok := l.Get(i)
		if !ok {
			t.Fatalf("missing: %d", i)
		}
	}
	if n := l.recent.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.frequent.Len(); n != 128 {
		t.Fatalf("bad: %d", n)
	}
}
+
+func Test2Q_Add_RecentToFrequent(t *testing.T) {
+ l, err := New2Q(128)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ // Add initially to recent
+ l.Add(1, 1)
+ if n := l.recent.Len(); n != 1 {
+ t.Fatalf("bad: %d", n)
+ }
+ if n := l.frequent.Len(); n != 0 {
+ t.Fatalf("bad: %d", n)
+ }
+
+ // Add should upgrade to frequent
+ l.Add(1, 1)
+ if n := l.recent.Len(); n != 0 {
+ t.Fatalf("bad: %d", n)
+ }
+ if n := l.frequent.Len(); n != 1 {
+ t.Fatalf("bad: %d", n)
+ }
+
+ // Add should remain in frequent
+ l.Add(1, 1)
+ if n := l.recent.Len(); n != 0 {
+ t.Fatalf("bad: %d", n)
+ }
+ if n := l.frequent.Len(); n != 1 {
+ t.Fatalf("bad: %d", n)
+ }
+}
+
+func Test2Q_Add_RecentEvict(t *testing.T) {
+ l, err := New2Q(4)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ // Add 1,2,3,4,5 -> Evict 1
+ l.Add(1, 1)
+ l.Add(2, 2)
+ l.Add(3, 3)
+ l.Add(4, 4)
+ l.Add(5, 5)
+ if n := l.recent.Len(); n != 4 {
+ t.Fatalf("bad: %d", n)
+ }
+ if n := l.recentEvict.Len(); n != 1 {
+ t.Fatalf("bad: %d", n)
+ }
+ if n := l.frequent.Len(); n != 0 {
+ t.Fatalf("bad: %d", n)
+ }
+
+ // Pull in the recently evicted
+ l.Add(1, 1)
+ if n := l.recent.Len(); n != 3 {
+ t.Fatalf("bad: %d", n)
+ }
+ if n := l.recentEvict.Len(); n != 1 {
+ t.Fatalf("bad: %d", n)
+ }
+ if n := l.frequent.Len(); n != 1 {
+ t.Fatalf("bad: %d", n)
+ }
+
+ // Add 6, should cause another recent evict
+ l.Add(6, 6)
+ if n := l.recent.Len(); n != 3 {
+ t.Fatalf("bad: %d", n)
+ }
+ if n := l.recentEvict.Len(); n != 2 {
+ t.Fatalf("bad: %d", n)
+ }
+ if n := l.frequent.Len(); n != 1 {
+ t.Fatalf("bad: %d", n)
+ }
+}
+
// Test2Q is an end-to-end exercise: overflow, key enumeration, eviction,
// removal and purge.
func Test2Q(t *testing.T) {
	l, err := New2Q(128)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Insert 256 distinct keys into a 128-entry cache. With no
	// intervening Gets everything stays on the recent list, so the
	// first 128 keys are evicted.
	for i := 0; i < 256; i++ {
		l.Add(i, i)
	}
	if l.Len() != 128 {
		t.Fatalf("bad len: %v", l.Len())
	}

	// Keys() lists frequent keys first, then recent keys oldest-first;
	// frequent is empty here, so the i-th key must be i+128.
	for i, k := range l.Keys() {
		if v, ok := l.Get(k); !ok || v != k || v != i+128 {
			t.Fatalf("bad key: %v", k)
		}
	}
	// Keys 0..127 must have been evicted.
	for i := 0; i < 128; i++ {
		_, ok := l.Get(i)
		if ok {
			t.Fatalf("should be evicted")
		}
	}
	// Keys 128..255 must still be present.
	for i := 128; i < 256; i++ {
		_, ok := l.Get(i)
		if !ok {
			t.Fatalf("should not be evicted")
		}
	}
	// Explicit removal must make keys unreachable.
	for i := 128; i < 192; i++ {
		l.Remove(i)
		_, ok := l.Get(i)
		if ok {
			t.Fatalf("should be deleted")
		}
	}

	// Purge must empty the cache entirely.
	l.Purge()
	if l.Len() != 0 {
		t.Fatalf("bad len: %v", l.Len())
	}
	if _, ok := l.Get(200); ok {
		t.Fatalf("should contain nothing")
	}
}
+
+// Test that Contains doesn't update recent-ness
+func Test2Q_Contains(t *testing.T) {
+ l, err := New2Q(2)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ l.Add(1, 1)
+ l.Add(2, 2)
+ if !l.Contains(1) {
+ t.Errorf("1 should be contained")
+ }
+
+ l.Add(3, 3)
+ if l.Contains(1) {
+ t.Errorf("Contains should not have updated recent-ness of 1")
+ }
+}
+
+// Test that Peek doesn't update recent-ness
+func Test2Q_Peek(t *testing.T) {
+ l, err := New2Q(2)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ l.Add(1, 1)
+ l.Add(2, 2)
+ if v, ok := l.Peek(1); !ok || v != 1 {
+ t.Errorf("1 should be set to 1: %v, %v", v, ok)
+ }
+
+ l.Add(3, 3)
+ if l.Contains(1) {
+ t.Errorf("should not have updated recent-ness of 1")
+ }
+}
--- /dev/null
+Mozilla Public License, version 2.0
+
+1. Definitions
+
+1.1. "Contributor"
+
+ means each individual or legal entity that creates, contributes to the
+ creation of, or owns Covered Software.
+
+1.2. "Contributor Version"
+
+ means the combination of the Contributions of others (if any) used by a
+ Contributor and that particular Contributor's Contribution.
+
+1.3. "Contribution"
+
+ means Covered Software of a particular Contributor.
+
+1.4. "Covered Software"
+
+ means Source Code Form to which the initial Contributor has attached the
+ notice in Exhibit A, the Executable Form of such Source Code Form, and
+ Modifications of such Source Code Form, in each case including portions
+ thereof.
+
+1.5. "Incompatible With Secondary Licenses"
+ means
+
+ a. that the initial Contributor has attached the notice described in
+ Exhibit B to the Covered Software; or
+
+ b. that the Covered Software was made available under the terms of
+ version 1.1 or earlier of the License, but not also under the terms of
+ a Secondary License.
+
+1.6. "Executable Form"
+
+ means any form of the work other than Source Code Form.
+
+1.7. "Larger Work"
+
+ means a work that combines Covered Software with other material, in a
+ separate file or files, that is not Covered Software.
+
+1.8. "License"
+
+ means this document.
+
+1.9. "Licensable"
+
+ means having the right to grant, to the maximum extent possible, whether
+ at the time of the initial grant or subsequently, any and all of the
+ rights conveyed by this License.
+
+1.10. "Modifications"
+
+ means any of the following:
+
+ a. any file in Source Code Form that results from an addition to,
+ deletion from, or modification of the contents of Covered Software; or
+
+ b. any new file in Source Code Form that contains any Covered Software.
+
+1.11. "Patent Claims" of a Contributor
+
+ means any patent claim(s), including without limitation, method,
+ process, and apparatus claims, in any patent Licensable by such
+ Contributor that would be infringed, but for the grant of the License,
+ by the making, using, selling, offering for sale, having made, import,
+ or transfer of either its Contributions or its Contributor Version.
+
+1.12. "Secondary License"
+
+ means either the GNU General Public License, Version 2.0, the GNU Lesser
+ General Public License, Version 2.1, the GNU Affero General Public
+ License, Version 3.0, or any later versions of those licenses.
+
+1.13. "Source Code Form"
+
+ means the form of the work preferred for making modifications.
+
+1.14. "You" (or "Your")
+
+ means an individual or a legal entity exercising rights under this
+ License. For legal entities, "You" includes any entity that controls, is
+ controlled by, or is under common control with You. For purposes of this
+ definition, "control" means (a) the power, direct or indirect, to cause
+ the direction or management of such entity, whether by contract or
+ otherwise, or (b) ownership of more than fifty percent (50%) of the
+ outstanding shares or beneficial ownership of such entity.
+
+
+2. License Grants and Conditions
+
+2.1. Grants
+
+ Each Contributor hereby grants You a world-wide, royalty-free,
+ non-exclusive license:
+
+ a. under intellectual property rights (other than patent or trademark)
+ Licensable by such Contributor to use, reproduce, make available,
+ modify, display, perform, distribute, and otherwise exploit its
+ Contributions, either on an unmodified basis, with Modifications, or
+ as part of a Larger Work; and
+
+ b. under Patent Claims of such Contributor to make, use, sell, offer for
+ sale, have made, import, and otherwise transfer either its
+ Contributions or its Contributor Version.
+
+2.2. Effective Date
+
+ The licenses granted in Section 2.1 with respect to any Contribution
+ become effective for each Contribution on the date the Contributor first
+ distributes such Contribution.
+
+2.3. Limitations on Grant Scope
+
+ The licenses granted in this Section 2 are the only rights granted under
+ this License. No additional rights or licenses will be implied from the
+ distribution or licensing of Covered Software under this License.
+ Notwithstanding Section 2.1(b) above, no patent license is granted by a
+ Contributor:
+
+ a. for any code that a Contributor has removed from Covered Software; or
+
+ b. for infringements caused by: (i) Your and any other third party's
+ modifications of Covered Software, or (ii) the combination of its
+ Contributions with other software (except as part of its Contributor
+ Version); or
+
+ c. under Patent Claims infringed by Covered Software in the absence of
+ its Contributions.
+
+ This License does not grant any rights in the trademarks, service marks,
+ or logos of any Contributor (except as may be necessary to comply with
+ the notice requirements in Section 3.4).
+
+2.4. Subsequent Licenses
+
+ No Contributor makes additional grants as a result of Your choice to
+ distribute the Covered Software under a subsequent version of this
+ License (see Section 10.2) or under the terms of a Secondary License (if
+ permitted under the terms of Section 3.3).
+
+2.5. Representation
+
+ Each Contributor represents that the Contributor believes its
+ Contributions are its original creation(s) or it has sufficient rights to
+ grant the rights to its Contributions conveyed by this License.
+
+2.6. Fair Use
+
+ This License is not intended to limit any rights You have under
+ applicable copyright doctrines of fair use, fair dealing, or other
+ equivalents.
+
+2.7. Conditions
+
+ Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in
+ Section 2.1.
+
+
+3. Responsibilities
+
+3.1. Distribution of Source Form
+
+ All distribution of Covered Software in Source Code Form, including any
+ Modifications that You create or to which You contribute, must be under
+ the terms of this License. You must inform recipients that the Source
+ Code Form of the Covered Software is governed by the terms of this
+ License, and how they can obtain a copy of this License. You may not
+ attempt to alter or restrict the recipients' rights in the Source Code
+ Form.
+
+3.2. Distribution of Executable Form
+
+ If You distribute Covered Software in Executable Form then:
+
+ a. such Covered Software must also be made available in Source Code Form,
+ as described in Section 3.1, and You must inform recipients of the
+ Executable Form how they can obtain a copy of such Source Code Form by
+ reasonable means in a timely manner, at a charge no more than the cost
+ of distribution to the recipient; and
+
+ b. You may distribute such Executable Form under the terms of this
+ License, or sublicense it under different terms, provided that the
+ license for the Executable Form does not attempt to limit or alter the
+ recipients' rights in the Source Code Form under this License.
+
+3.3. Distribution of a Larger Work
+
+ You may create and distribute a Larger Work under terms of Your choice,
+ provided that You also comply with the requirements of this License for
+ the Covered Software. If the Larger Work is a combination of Covered
+ Software with a work governed by one or more Secondary Licenses, and the
+ Covered Software is not Incompatible With Secondary Licenses, this
+ License permits You to additionally distribute such Covered Software
+ under the terms of such Secondary License(s), so that the recipient of
+ the Larger Work may, at their option, further distribute the Covered
+ Software under the terms of either this License or such Secondary
+ License(s).
+
+3.4. Notices
+
+ You may not remove or alter the substance of any license notices
+ (including copyright notices, patent notices, disclaimers of warranty, or
+ limitations of liability) contained within the Source Code Form of the
+ Covered Software, except that You may alter any license notices to the
+ extent required to remedy known factual inaccuracies.
+
+3.5. Application of Additional Terms
+
+ You may choose to offer, and to charge a fee for, warranty, support,
+ indemnity or liability obligations to one or more recipients of Covered
+ Software. However, You may do so only on Your own behalf, and not on
+ behalf of any Contributor. You must make it absolutely clear that any
+ such warranty, support, indemnity, or liability obligation is offered by
+ You alone, and You hereby agree to indemnify every Contributor for any
+ liability incurred by such Contributor as a result of warranty, support,
+ indemnity or liability terms You offer. You may include additional
+ disclaimers of warranty and limitations of liability specific to any
+ jurisdiction.
+
+4. Inability to Comply Due to Statute or Regulation
+
+ If it is impossible for You to comply with any of the terms of this License
+ with respect to some or all of the Covered Software due to statute,
+ judicial order, or regulation then You must: (a) comply with the terms of
+ this License to the maximum extent possible; and (b) describe the
+ limitations and the code they affect. Such description must be placed in a
+ text file included with all distributions of the Covered Software under
+ this License. Except to the extent prohibited by statute or regulation,
+ such description must be sufficiently detailed for a recipient of ordinary
+ skill to be able to understand it.
+
+5. Termination
+
+5.1. The rights granted under this License will terminate automatically if You
+ fail to comply with any of its terms. However, if You become compliant,
+ then the rights granted under this License from a particular Contributor
+ are reinstated (a) provisionally, unless and until such Contributor
+ explicitly and finally terminates Your grants, and (b) on an ongoing
+ basis, if such Contributor fails to notify You of the non-compliance by
+ some reasonable means prior to 60 days after You have come back into
+ compliance. Moreover, Your grants from a particular Contributor are
+ reinstated on an ongoing basis if such Contributor notifies You of the
+ non-compliance by some reasonable means, this is the first time You have
+ received notice of non-compliance with this License from such
+ Contributor, and You become compliant prior to 30 days after Your receipt
+ of the notice.
+
+5.2. If You initiate litigation against any entity by asserting a patent
+ infringement claim (excluding declaratory judgment actions,
+ counter-claims, and cross-claims) alleging that a Contributor Version
+ directly or indirectly infringes any patent, then the rights granted to
+ You by any and all Contributors for the Covered Software under Section
+ 2.1 of this License shall terminate.
+
+5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user
+ license agreements (excluding distributors and resellers) which have been
+ validly granted by You or Your distributors under this License prior to
+ termination shall survive termination.
+
+6. Disclaimer of Warranty
+
+ Covered Software is provided under this License on an "as is" basis,
+ without warranty of any kind, either expressed, implied, or statutory,
+ including, without limitation, warranties that the Covered Software is free
+ of defects, merchantable, fit for a particular purpose or non-infringing.
+ The entire risk as to the quality and performance of the Covered Software
+ is with You. Should any Covered Software prove defective in any respect,
+ You (not any Contributor) assume the cost of any necessary servicing,
+ repair, or correction. This disclaimer of warranty constitutes an essential
+ part of this License. No use of any Covered Software is authorized under
+ this License except under this disclaimer.
+
+7. Limitation of Liability
+
+ Under no circumstances and under no legal theory, whether tort (including
+ negligence), contract, or otherwise, shall any Contributor, or anyone who
+ distributes Covered Software as permitted above, be liable to You for any
+ direct, indirect, special, incidental, or consequential damages of any
+ character including, without limitation, damages for lost profits, loss of
+ goodwill, work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses, even if such party shall have been
+ informed of the possibility of such damages. This limitation of liability
+ shall not apply to liability for death or personal injury resulting from
+ such party's negligence to the extent applicable law prohibits such
+ limitation. Some jurisdictions do not allow the exclusion or limitation of
+ incidental or consequential damages, so this exclusion and limitation may
+ not apply to You.
+
+8. Litigation
+
+ Any litigation relating to this License may be brought only in the courts
+ of a jurisdiction where the defendant maintains its principal place of
+ business and such litigation shall be governed by laws of that
+ jurisdiction, without reference to its conflict-of-law provisions. Nothing
+ in this Section shall prevent a party's ability to bring cross-claims or
+ counter-claims.
+
+9. Miscellaneous
+
+ This License represents the complete agreement concerning the subject
+ matter hereof. If any provision of this License is held to be
+ unenforceable, such provision shall be reformed only to the extent
+ necessary to make it enforceable. Any law or regulation which provides that
+ the language of a contract shall be construed against the drafter shall not
+ be used to construe this License against a Contributor.
+
+
+10. Versions of the License
+
+10.1. New Versions
+
+ Mozilla Foundation is the license steward. Except as provided in Section
+ 10.3, no one other than the license steward has the right to modify or
+ publish new versions of this License. Each version will be given a
+ distinguishing version number.
+
+10.2. Effect of New Versions
+
+ You may distribute the Covered Software under the terms of the version
+ of the License under which You originally received the Covered Software,
+ or under the terms of any subsequent version published by the license
+ steward.
+
+10.3. Modified Versions
+
+ If you create software not governed by this License, and you want to
+ create a new license for such software, you may create and use a
+ modified version of this License if you rename the license and remove
+ any references to the name of the license steward (except to note that
+ such modified license differs from this License).
+
+10.4. Distributing Source Code Form that is Incompatible With Secondary
+ Licenses If You choose to distribute Source Code Form that is
+ Incompatible With Secondary Licenses under the terms of this version of
+ the License, the notice described in Exhibit B of this License must be
+ attached.
+
+Exhibit A - Source Code Form License Notice
+
+ This Source Code Form is subject to the
+ terms of the Mozilla Public License, v.
+ 2.0. If a copy of the MPL was not
+ distributed with this file, You can
+ obtain one at
+ http://mozilla.org/MPL/2.0/.
+
+If it is not possible or desirable to put the notice in a particular file,
+then You may include the notice in a location (such as a LICENSE file in a
+relevant directory) where a recipient would be likely to look for such a
+notice.
+
+You may add additional accurate notices of copyright ownership.
+
+Exhibit B - "Incompatible With Secondary Licenses" Notice
+
+ This Source Code Form is "Incompatible
+ With Secondary Licenses", as defined by
+ the Mozilla Public License, v. 2.0.
--- /dev/null
+golang-lru
+==========
+
This provides the `lru` package which implements a fixed-size
thread-safe LRU cache. It is based on the cache in Groupcache.
+
+Documentation
+=============
+
+Full docs are available on [Godoc](http://godoc.org/github.com/hashicorp/golang-lru)
+
+Example
+=======
+
+Using the LRU is very simple:
+
+```go
+l, _ := New(128)
+for i := 0; i < 256; i++ {
+ l.Add(i, nil)
+}
+if l.Len() != 128 {
+ panic(fmt.Sprintf("bad len: %v", l.Len()))
+}
+```
--- /dev/null
+package lru
+
+import (
+ "sync"
+
+ "github.com/hashicorp/golang-lru/simplelru"
+)
+
// ARCCache is a thread-safe fixed size Adaptive Replacement Cache (ARC).
// ARC is an enhancement over the standard LRU cache in that tracks both
// frequency and recency of use. This avoids a burst in access to new
// entries from evicting the frequently used older entries. It adds some
// additional tracking overhead to a standard LRU cache, computationally
// it is roughly 2x the cost, and the extra memory overhead is linear
// with the size of the cache. ARC has been patented by IBM, but is
// similar to the TwoQueueCache (2Q) which requires setting parameters.
type ARCCache struct {
	size int // Size is the total capacity of the cache
	p    int // P is the dynamic preference towards T1 or T2

	t1 simplelru.LRUCache // T1 is the LRU for recently accessed items
	b1 simplelru.LRUCache // B1 is the LRU for evictions from t1

	t2 simplelru.LRUCache // T2 is the LRU for frequently accessed items
	b2 simplelru.LRUCache // B2 is the LRU for evictions from t2

	lock sync.RWMutex // lock guards all fields above for concurrent use
}
+
+// NewARC creates an ARC of the given size
+func NewARC(size int) (*ARCCache, error) {
+ // Create the sub LRUs
+ b1, err := simplelru.NewLRU(size, nil)
+ if err != nil {
+ return nil, err
+ }
+ b2, err := simplelru.NewLRU(size, nil)
+ if err != nil {
+ return nil, err
+ }
+ t1, err := simplelru.NewLRU(size, nil)
+ if err != nil {
+ return nil, err
+ }
+ t2, err := simplelru.NewLRU(size, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ // Initialize the ARC
+ c := &ARCCache{
+ size: size,
+ p: 0,
+ t1: t1,
+ b1: b1,
+ t2: t2,
+ b2: b2,
+ }
+ return c, nil
+}
+
+// Get looks up a key's value from the cache.
+func (c *ARCCache) Get(key interface{}) (value interface{}, ok bool) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ // If the value is contained in T1 (recent), then
+ // promote it to T2 (frequent)
+ if val, ok := c.t1.Peek(key); ok {
+ c.t1.Remove(key)
+ c.t2.Add(key, val)
+ return val, ok
+ }
+
+ // Check if the value is contained in T2 (frequent)
+ if val, ok := c.t2.Get(key); ok {
+ return val, ok
+ }
+
+ // No hit
+ return nil, false
+}
+
+// Add adds a value to the cache.
+func (c *ARCCache) Add(key, value interface{}) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ // Check if the value is contained in T1 (recent), and potentially
+ // promote it to frequent T2
+ if c.t1.Contains(key) {
+ c.t1.Remove(key)
+ c.t2.Add(key, value)
+ return
+ }
+
+ // Check if the value is already in T2 (frequent) and update it
+ if c.t2.Contains(key) {
+ c.t2.Add(key, value)
+ return
+ }
+
+ // Check if this value was recently evicted as part of the
+ // recently used list
+ if c.b1.Contains(key) {
+ // T1 set is too small, increase P appropriately
+ delta := 1
+ b1Len := c.b1.Len()
+ b2Len := c.b2.Len()
+ if b2Len > b1Len {
+ delta = b2Len / b1Len
+ }
+ if c.p+delta >= c.size {
+ c.p = c.size
+ } else {
+ c.p += delta
+ }
+
+ // Potentially need to make room in the cache
+ if c.t1.Len()+c.t2.Len() >= c.size {
+ c.replace(false)
+ }
+
+ // Remove from B1
+ c.b1.Remove(key)
+
+ // Add the key to the frequently used list
+ c.t2.Add(key, value)
+ return
+ }
+
+ // Check if this value was recently evicted as part of the
+ // frequently used list
+ if c.b2.Contains(key) {
+ // T2 set is too small, decrease P appropriately
+ delta := 1
+ b1Len := c.b1.Len()
+ b2Len := c.b2.Len()
+ if b1Len > b2Len {
+ delta = b1Len / b2Len
+ }
+ if delta >= c.p {
+ c.p = 0
+ } else {
+ c.p -= delta
+ }
+
+ // Potentially need to make room in the cache
+ if c.t1.Len()+c.t2.Len() >= c.size {
+ c.replace(true)
+ }
+
+ // Remove from B2
+ c.b2.Remove(key)
+
+ // Add the key to the frequently used list
+ c.t2.Add(key, value)
+ return
+ }
+
+ // Potentially need to make room in the cache
+ if c.t1.Len()+c.t2.Len() >= c.size {
+ c.replace(false)
+ }
+
+ // Keep the size of the ghost buffers trim
+ if c.b1.Len() > c.size-c.p {
+ c.b1.RemoveOldest()
+ }
+ if c.b2.Len() > c.p {
+ c.b2.RemoveOldest()
+ }
+
+ // Add to the recently seen list
+ c.t1.Add(key, value)
+ return
+}
+
+// replace is used to adaptively evict from either T1 or T2
+// based on the current learned value of P
+func (c *ARCCache) replace(b2ContainsKey bool) {
+ t1Len := c.t1.Len()
+ if t1Len > 0 && (t1Len > c.p || (t1Len == c.p && b2ContainsKey)) {
+ k, _, ok := c.t1.RemoveOldest()
+ if ok {
+ c.b1.Add(k, nil)
+ }
+ } else {
+ k, _, ok := c.t2.RemoveOldest()
+ if ok {
+ c.b2.Add(k, nil)
+ }
+ }
+}
+
+// Len returns the number of cached entries
+func (c *ARCCache) Len() int {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+ return c.t1.Len() + c.t2.Len()
+}
+
+// Keys returns all the cached keys
+func (c *ARCCache) Keys() []interface{} {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+ k1 := c.t1.Keys()
+ k2 := c.t2.Keys()
+ return append(k1, k2...)
+}
+
+// Remove is used to purge a key from the cache
+func (c *ARCCache) Remove(key interface{}) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ if c.t1.Remove(key) {
+ return
+ }
+ if c.t2.Remove(key) {
+ return
+ }
+ if c.b1.Remove(key) {
+ return
+ }
+ if c.b2.Remove(key) {
+ return
+ }
+}
+
+// Purge is used to clear the cache
+func (c *ARCCache) Purge() {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ c.t1.Purge()
+ c.t2.Purge()
+ c.b1.Purge()
+ c.b2.Purge()
+}
+
+// Contains is used to check if the cache contains a key
+// without updating recency or frequency.
+func (c *ARCCache) Contains(key interface{}) bool {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+ return c.t1.Contains(key) || c.t2.Contains(key)
+}
+
+// Peek is used to inspect the cache value of a key
+// without updating recency or frequency.
+func (c *ARCCache) Peek(key interface{}) (value interface{}, ok bool) {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+ if val, ok := c.t1.Peek(key); ok {
+ return val, ok
+ }
+ return c.t2.Peek(key)
+}
--- /dev/null
+package lru
+
+import (
+ "math/rand"
+ "testing"
+ "time"
+)
+
// init seeds the global math/rand source so each test run sees a
// different random trace.
func init() {
	rand.Seed(time.Now().Unix())
}
+
+func BenchmarkARC_Rand(b *testing.B) {
+ l, err := NewARC(8192)
+ if err != nil {
+ b.Fatalf("err: %v", err)
+ }
+
+ trace := make([]int64, b.N*2)
+ for i := 0; i < b.N*2; i++ {
+ trace[i] = rand.Int63() % 32768
+ }
+
+ b.ResetTimer()
+
+ var hit, miss int
+ for i := 0; i < 2*b.N; i++ {
+ if i%2 == 0 {
+ l.Add(trace[i], trace[i])
+ } else {
+ _, ok := l.Get(trace[i])
+ if ok {
+ hit++
+ } else {
+ miss++
+ }
+ }
+ }
+ b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(miss))
+}
+
+func BenchmarkARC_Freq(b *testing.B) {
+ l, err := NewARC(8192)
+ if err != nil {
+ b.Fatalf("err: %v", err)
+ }
+
+ trace := make([]int64, b.N*2)
+ for i := 0; i < b.N*2; i++ {
+ if i%2 == 0 {
+ trace[i] = rand.Int63() % 16384
+ } else {
+ trace[i] = rand.Int63() % 32768
+ }
+ }
+
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ l.Add(trace[i], trace[i])
+ }
+ var hit, miss int
+ for i := 0; i < b.N; i++ {
+ _, ok := l.Get(trace[i])
+ if ok {
+ hit++
+ } else {
+ miss++
+ }
+ }
+ b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(miss))
+}
+
+func TestARC_RandomOps(t *testing.T) {
+ size := 128
+ l, err := NewARC(128)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ n := 200000
+ for i := 0; i < n; i++ {
+ key := rand.Int63() % 512
+ r := rand.Int63()
+ switch r % 3 {
+ case 0:
+ l.Add(key, key)
+ case 1:
+ l.Get(key)
+ case 2:
+ l.Remove(key)
+ }
+
+ if l.t1.Len()+l.t2.Len() > size {
+ t.Fatalf("bad: t1: %d t2: %d b1: %d b2: %d p: %d",
+ l.t1.Len(), l.t2.Len(), l.b1.Len(), l.b2.Len(), l.p)
+ }
+ if l.b1.Len()+l.b2.Len() > size {
+ t.Fatalf("bad: t1: %d t2: %d b1: %d b2: %d p: %d",
+ l.t1.Len(), l.t2.Len(), l.b1.Len(), l.b2.Len(), l.p)
+ }
+ }
+}
+
// TestARC_Get_RecentToFrequent verifies that a first Add lands entries
// in t1 (recent) and that a subsequent Get promotes them to t2
// (frequent), where they then stay.
func TestARC_Get_RecentToFrequent(t *testing.T) {
	l, err := NewARC(128)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Touch all the entries, should be in t1
	for i := 0; i < 128; i++ {
		l.Add(i, i)
	}
	if n := l.t1.Len(); n != 128 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.t2.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}

	// Get should upgrade to t2
	for i := 0; i < 128; i++ {
		_, ok := l.Get(i)
		if !ok {
			t.Fatalf("missing: %d", i)
		}
	}
	if n := l.t1.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.t2.Len(); n != 128 {
		t.Fatalf("bad: %d", n)
	}

	// Get should now be served from t2, leaving the lists unchanged
	for i := 0; i < 128; i++ {
		_, ok := l.Get(i)
		if !ok {
			t.Fatalf("missing: %d", i)
		}
	}
	if n := l.t1.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.t2.Len(); n != 128 {
		t.Fatalf("bad: %d", n)
	}
}
+
+func TestARC_Add_RecentToFrequent(t *testing.T) {
+ l, err := NewARC(128)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ // Add initially to t1
+ l.Add(1, 1)
+ if n := l.t1.Len(); n != 1 {
+ t.Fatalf("bad: %d", n)
+ }
+ if n := l.t2.Len(); n != 0 {
+ t.Fatalf("bad: %d", n)
+ }
+
+ // Add should upgrade to t2
+ l.Add(1, 1)
+ if n := l.t1.Len(); n != 0 {
+ t.Fatalf("bad: %d", n)
+ }
+ if n := l.t2.Len(); n != 1 {
+ t.Fatalf("bad: %d", n)
+ }
+
+ // Add should remain in t2
+ l.Add(1, 1)
+ if n := l.t1.Len(); n != 0 {
+ t.Fatalf("bad: %d", n)
+ }
+ if n := l.t2.Len(); n != 1 {
+ t.Fatalf("bad: %d", n)
+ }
+}
+
// TestARC_Adaptive walks a small cache through a scripted sequence of
// operations, checking the list contents and the adaptive target p at
// each step. The "Current state" comments show the lists MRU-to-LRU
// after each step.
func TestARC_Adaptive(t *testing.T) {
	l, err := NewARC(4)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Fill t1
	for i := 0; i < 4; i++ {
		l.Add(i, i)
	}
	if n := l.t1.Len(); n != 4 {
		t.Fatalf("bad: %d", n)
	}

	// Move to t2
	l.Get(0)
	l.Get(1)
	if n := l.t2.Len(); n != 2 {
		t.Fatalf("bad: %d", n)
	}

	// Evict from t1
	l.Add(4, 4)
	if n := l.b1.Len(); n != 1 {
		t.Fatalf("bad: %d", n)
	}

	// Current state
	// t1 : (MRU) [4, 3] (LRU)
	// t2 : (MRU) [1, 0] (LRU)
	// b1 : (MRU) [2] (LRU)
	// b2 : (MRU) [] (LRU)

	// Add 2, should cause hit on b1
	l.Add(2, 2)
	if n := l.b1.Len(); n != 1 {
		t.Fatalf("bad: %d", n)
	}
	if l.p != 1 {
		t.Fatalf("bad: %d", l.p)
	}
	if n := l.t2.Len(); n != 3 {
		t.Fatalf("bad: %d", n)
	}

	// Current state
	// t1 : (MRU) [4] (LRU)
	// t2 : (MRU) [2, 1, 0] (LRU)
	// b1 : (MRU) [3] (LRU)
	// b2 : (MRU) [] (LRU)

	// Add 4, should migrate to t2
	l.Add(4, 4)
	if n := l.t1.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.t2.Len(); n != 4 {
		t.Fatalf("bad: %d", n)
	}

	// Current state
	// t1 : (MRU) [] (LRU)
	// t2 : (MRU) [4, 2, 1, 0] (LRU)
	// b1 : (MRU) [3] (LRU)
	// b2 : (MRU) [] (LRU)

	// Add 5, should evict the t2 LRU entry (0) to b2
	l.Add(5, 5)
	if n := l.t1.Len(); n != 1 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.t2.Len(); n != 3 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.b2.Len(); n != 1 {
		t.Fatalf("bad: %d", n)
	}

	// Current state
	// t1 : (MRU) [5] (LRU)
	// t2 : (MRU) [4, 2, 1] (LRU)
	// b1 : (MRU) [3] (LRU)
	// b2 : (MRU) [0] (LRU)

	// Add 0, should hit b2 and decrease p
	l.Add(0, 0)
	if n := l.t1.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.t2.Len(); n != 4 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.b1.Len(); n != 2 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.b2.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}
	if l.p != 0 {
		t.Fatalf("bad: %d", l.p)
	}

	// Current state
	// t1 : (MRU) [] (LRU)
	// t2 : (MRU) [0, 4, 2, 1] (LRU)
	// b1 : (MRU) [5, 3] (LRU)
	// b2 : (MRU) [] (LRU)
}
+
// TestARC exercises the basic cache contract: capacity is enforced,
// the oldest half of an over-full insert stream is evicted, Remove
// deletes entries, and Purge empties the cache.
func TestARC(t *testing.T) {
	l, err := NewARC(128)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Insert twice the capacity; only the newest 128 survive.
	for i := 0; i < 256; i++ {
		l.Add(i, i)
	}
	if l.Len() != 128 {
		t.Fatalf("bad len: %v", l.Len())
	}

	// Keys are reported oldest-to-newest, so position i maps to key i+128.
	for i, k := range l.Keys() {
		if v, ok := l.Get(k); !ok || v != k || v != i+128 {
			t.Fatalf("bad key: %v", k)
		}
	}
	// The first 128 keys were evicted by the second wave of Adds.
	for i := 0; i < 128; i++ {
		_, ok := l.Get(i)
		if ok {
			t.Fatalf("should be evicted")
		}
	}
	// The last 128 keys are still present.
	for i := 128; i < 256; i++ {
		_, ok := l.Get(i)
		if !ok {
			t.Fatalf("should not be evicted")
		}
	}
	// Remove half of the survivors and verify they are gone.
	for i := 128; i < 192; i++ {
		l.Remove(i)
		_, ok := l.Get(i)
		if ok {
			t.Fatalf("should be deleted")
		}
	}

	// Purge drops everything.
	l.Purge()
	if l.Len() != 0 {
		t.Fatalf("bad len: %v", l.Len())
	}
	if _, ok := l.Get(200); ok {
		t.Fatalf("should contain nothing")
	}
}
+
+// Test that Contains doesn't update recent-ness
+func TestARC_Contains(t *testing.T) {
+ l, err := NewARC(2)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ l.Add(1, 1)
+ l.Add(2, 2)
+ if !l.Contains(1) {
+ t.Errorf("1 should be contained")
+ }
+
+ l.Add(3, 3)
+ if l.Contains(1) {
+ t.Errorf("Contains should not have updated recent-ness of 1")
+ }
+}
+
+// Test that Peek doesn't update recent-ness
+func TestARC_Peek(t *testing.T) {
+ l, err := NewARC(2)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ l.Add(1, 1)
+ l.Add(2, 2)
+ if v, ok := l.Peek(1); !ok || v != 1 {
+ t.Errorf("1 should be set to 1: %v, %v", v, ok)
+ }
+
+ l.Add(3, 3)
+ if l.Contains(1) {
+ t.Errorf("should not have updated recent-ness of 1")
+ }
+}
--- /dev/null
+// Package lru provides three different LRU caches of varying sophistication.
+//
+// Cache is a simple LRU cache. It is based on the
+// LRU implementation in groupcache:
+// https://github.com/golang/groupcache/tree/master/lru
+//
+// TwoQueueCache tracks frequently used and recently used entries separately.
+// This avoids a burst of accesses from taking out frequently used entries,
+// at the cost of about 2x computational overhead and some extra bookkeeping.
+//
+// ARCCache is an adaptive replacement cache. It tracks recent evictions as
+// well as recent usage in both the frequent and recent caches. Its
+// computational overhead is comparable to TwoQueueCache, but the memory
+// overhead is linear with the size of the cache.
+//
+// ARC has been patented by IBM, so do not use it if that is problematic for
+// your program.
+//
+// All caches in this package take locks while operating, and are therefore
+// thread-safe for consumers.
+package lru
--- /dev/null
+module github.com/hashicorp/golang-lru
--- /dev/null
+package lru
+
+import (
+ "sync"
+
+ "github.com/hashicorp/golang-lru/simplelru"
+)
+
// Cache is a thread-safe fixed size LRU cache.
// It wraps a non-thread-safe simplelru.LRUCache with an RWMutex.
type Cache struct {
	lru  simplelru.LRUCache // underlying non-thread-safe LRU
	lock sync.RWMutex       // guards all access to lru
}
+
// New creates an LRU of the given size.
// It is equivalent to NewWithEvict with a nil eviction callback.
func New(size int) (*Cache, error) {
	return NewWithEvict(size, nil)
}
+
+// NewWithEvict constructs a fixed size cache with the given eviction
+// callback.
+func NewWithEvict(size int, onEvicted func(key interface{}, value interface{})) (*Cache, error) {
+ lru, err := simplelru.NewLRU(size, simplelru.EvictCallback(onEvicted))
+ if err != nil {
+ return nil, err
+ }
+ c := &Cache{
+ lru: lru,
+ }
+ return c, nil
+}
+
+// Purge is used to completely clear the cache.
+func (c *Cache) Purge() {
+ c.lock.Lock()
+ c.lru.Purge()
+ c.lock.Unlock()
+}
+
+// Add adds a value to the cache. Returns true if an eviction occurred.
+func (c *Cache) Add(key, value interface{}) (evicted bool) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ return c.lru.Add(key, value)
+}
+
+// Get looks up a key's value from the cache.
+func (c *Cache) Get(key interface{}) (value interface{}, ok bool) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ return c.lru.Get(key)
+}
+
+// Contains checks if a key is in the cache, without updating the
+// recent-ness or deleting it for being stale.
+func (c *Cache) Contains(key interface{}) bool {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+ return c.lru.Contains(key)
+}
+
+// Peek returns the key value (or undefined if not found) without updating
+// the "recently used"-ness of the key.
+func (c *Cache) Peek(key interface{}) (value interface{}, ok bool) {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+ return c.lru.Peek(key)
+}
+
+// ContainsOrAdd checks if a key is in the cache without updating the
+// recent-ness or deleting it for being stale, and if not, adds the value.
+// Returns whether found and whether an eviction occurred.
+func (c *Cache) ContainsOrAdd(key, value interface{}) (ok, evicted bool) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.lru.Contains(key) {
+ return true, false
+ }
+ evicted = c.lru.Add(key, value)
+ return false, evicted
+}
+
+// Remove removes the provided key from the cache.
+func (c *Cache) Remove(key interface{}) {
+ c.lock.Lock()
+ c.lru.Remove(key)
+ c.lock.Unlock()
+}
+
+// RemoveOldest removes the oldest item from the cache.
+func (c *Cache) RemoveOldest() {
+ c.lock.Lock()
+ c.lru.RemoveOldest()
+ c.lock.Unlock()
+}
+
+// Keys returns a slice of the keys in the cache, from oldest to newest.
+func (c *Cache) Keys() []interface{} {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+ return c.lru.Keys()
+}
+
+// Len returns the number of items in the cache.
+func (c *Cache) Len() int {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+ return c.lru.Len()
+}
--- /dev/null
+package lru
+
+import (
+ "math/rand"
+ "testing"
+)
+
+func BenchmarkLRU_Rand(b *testing.B) {
+ l, err := New(8192)
+ if err != nil {
+ b.Fatalf("err: %v", err)
+ }
+
+ trace := make([]int64, b.N*2)
+ for i := 0; i < b.N*2; i++ {
+ trace[i] = rand.Int63() % 32768
+ }
+
+ b.ResetTimer()
+
+ var hit, miss int
+ for i := 0; i < 2*b.N; i++ {
+ if i%2 == 0 {
+ l.Add(trace[i], trace[i])
+ } else {
+ _, ok := l.Get(trace[i])
+ if ok {
+ hit++
+ } else {
+ miss++
+ }
+ }
+ }
+ b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(miss))
+}
+
+func BenchmarkLRU_Freq(b *testing.B) {
+ l, err := New(8192)
+ if err != nil {
+ b.Fatalf("err: %v", err)
+ }
+
+ trace := make([]int64, b.N*2)
+ for i := 0; i < b.N*2; i++ {
+ if i%2 == 0 {
+ trace[i] = rand.Int63() % 16384
+ } else {
+ trace[i] = rand.Int63() % 32768
+ }
+ }
+
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ l.Add(trace[i], trace[i])
+ }
+ var hit, miss int
+ for i := 0; i < b.N; i++ {
+ _, ok := l.Get(trace[i])
+ if ok {
+ hit++
+ } else {
+ miss++
+ }
+ }
+ b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(miss))
+}
+
// TestLRU exercises the basic cache contract: capacity enforcement,
// eviction callbacks, key ordering, Remove, Get-driven recency, and
// Purge.
func TestLRU(t *testing.T) {
	evictCounter := 0
	// The eviction callback also checks the k == v invariant kept by
	// every Add in this test.
	onEvicted := func(k interface{}, v interface{}) {
		if k != v {
			t.Fatalf("Evict values not equal (%v!=%v)", k, v)
		}
		evictCounter++
	}
	l, err := NewWithEvict(128, onEvicted)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Insert twice the capacity; the oldest 128 must be evicted.
	for i := 0; i < 256; i++ {
		l.Add(i, i)
	}
	if l.Len() != 128 {
		t.Fatalf("bad len: %v", l.Len())
	}

	if evictCounter != 128 {
		t.Fatalf("bad evict count: %v", evictCounter)
	}

	// Keys are reported oldest-to-newest: position i maps to key i+128.
	for i, k := range l.Keys() {
		if v, ok := l.Get(k); !ok || v != k || v != i+128 {
			t.Fatalf("bad key: %v", k)
		}
	}
	for i := 0; i < 128; i++ {
		_, ok := l.Get(i)
		if ok {
			t.Fatalf("should be evicted")
		}
	}
	for i := 128; i < 256; i++ {
		_, ok := l.Get(i)
		if !ok {
			t.Fatalf("should not be evicted")
		}
	}
	for i := 128; i < 192; i++ {
		l.Remove(i)
		_, ok := l.Get(i)
		if ok {
			t.Fatalf("should be deleted")
		}
	}

	l.Get(192) // expect 192 to be last key in l.Keys()

	// After the Get, 192 moved to the newest slot; the rest stay in
	// ascending order starting at 193.
	for i, k := range l.Keys() {
		if (i < 63 && k != i+193) || (i == 63 && k != 192) {
			t.Fatalf("out of order key: %v", k)
		}
	}

	l.Purge()
	if l.Len() != 0 {
		t.Fatalf("bad len: %v", l.Len())
	}
	if _, ok := l.Get(200); ok {
		t.Fatalf("should contain nothing")
	}
}
+
+// test that Add returns true/false if an eviction occurred
+func TestLRUAdd(t *testing.T) {
+ evictCounter := 0
+ onEvicted := func(k interface{}, v interface{}) {
+ evictCounter++
+ }
+
+ l, err := NewWithEvict(1, onEvicted)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ if l.Add(1, 1) == true || evictCounter != 0 {
+ t.Errorf("should not have an eviction")
+ }
+ if l.Add(2, 2) == false || evictCounter != 1 {
+ t.Errorf("should have an eviction")
+ }
+}
+
+// test that Contains doesn't update recent-ness
+func TestLRUContains(t *testing.T) {
+ l, err := New(2)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ l.Add(1, 1)
+ l.Add(2, 2)
+ if !l.Contains(1) {
+ t.Errorf("1 should be contained")
+ }
+
+ l.Add(3, 3)
+ if l.Contains(1) {
+ t.Errorf("Contains should not have updated recent-ness of 1")
+ }
+}
+
+// test that Contains doesn't update recent-ness
+func TestLRUContainsOrAdd(t *testing.T) {
+ l, err := New(2)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ l.Add(1, 1)
+ l.Add(2, 2)
+ contains, evict := l.ContainsOrAdd(1, 1)
+ if !contains {
+ t.Errorf("1 should be contained")
+ }
+ if evict {
+ t.Errorf("nothing should be evicted here")
+ }
+
+ l.Add(3, 3)
+ contains, evict = l.ContainsOrAdd(1, 1)
+ if contains {
+ t.Errorf("1 should not have been contained")
+ }
+ if !evict {
+ t.Errorf("an eviction should have occurred")
+ }
+ if !l.Contains(1) {
+ t.Errorf("now 1 should be contained")
+ }
+}
+
+// test that Peek doesn't update recent-ness
+func TestLRUPeek(t *testing.T) {
+ l, err := New(2)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ l.Add(1, 1)
+ l.Add(2, 2)
+ if v, ok := l.Peek(1); !ok || v != 1 {
+ t.Errorf("1 should be set to 1: %v, %v", v, ok)
+ }
+
+ l.Add(3, 3)
+ if l.Contains(1) {
+ t.Errorf("should not have updated recent-ness of 1")
+ }
+}
--- /dev/null
+package simplelru
+
+import (
+ "container/list"
+ "errors"
+)
+
// EvictCallback is used to get a callback when a cache entry is evicted.
type EvictCallback func(key interface{}, value interface{})

// LRU implements a non-thread safe fixed size LRU cache.
type LRU struct {
	size      int                           // maximum entry count before eviction kicks in
	evictList *list.List                    // recency list: front = most recent, back = oldest
	items     map[interface{}]*list.Element // key -> list element, for O(1) lookup
	onEvict   EvictCallback                 // optional; fired whenever an entry leaves the cache
}

// entry is used to hold a value in the evictList.
type entry struct {
	key   interface{}
	value interface{}
}

// NewLRU constructs an LRU of the given size. onEvict may be nil;
// when non-nil it is invoked for every entry removed from the cache
// (capacity eviction, Remove, RemoveOldest, and Purge).
func NewLRU(size int, onEvict EvictCallback) (*LRU, error) {
	if size <= 0 {
		// Lowercase error string per Go convention (ST1005).
		return nil, errors.New("must provide a positive size")
	}
	c := &LRU{
		size:      size,
		evictList: list.New(),
		items:     make(map[interface{}]*list.Element),
		onEvict:   onEvict,
	}
	return c, nil
}

// Purge is used to completely clear the cache. The eviction
// callback, if any, is invoked for every entry.
func (c *LRU) Purge() {
	for k, v := range c.items {
		if c.onEvict != nil {
			c.onEvict(k, v.Value.(*entry).value)
		}
		delete(c.items, k)
	}
	c.evictList.Init()
}

// Add adds a value to the cache. Returns true if an eviction occurred.
func (c *LRU) Add(key, value interface{}) (evicted bool) {
	// Check for an existing item: update in place and mark most-recent.
	if ent, ok := c.items[key]; ok {
		c.evictList.MoveToFront(ent)
		ent.Value.(*entry).value = value
		return false
	}

	// Add the new item at the most-recent end.
	// (local renamed from "entry" to avoid shadowing the entry type)
	ent := &entry{key, value}
	elem := c.evictList.PushFront(ent)
	c.items[key] = elem

	// Verify size not exceeded.
	evict := c.evictList.Len() > c.size
	if evict {
		c.removeOldest()
	}
	return evict
}

// Get looks up a key's value from the cache, marking it most
// recently used on a hit.
func (c *LRU) Get(key interface{}) (value interface{}, ok bool) {
	if ent, ok := c.items[key]; ok {
		c.evictList.MoveToFront(ent)
		return ent.Value.(*entry).value, true
	}
	return
}

// Contains checks if a key is in the cache, without updating the recent-ness
// or deleting it for being stale.
func (c *LRU) Contains(key interface{}) (ok bool) {
	_, ok = c.items[key]
	return ok
}

// Peek returns the key's value (or nil if not found) without updating
// the "recently used"-ness of the key.
func (c *LRU) Peek(key interface{}) (value interface{}, ok bool) {
	if ent, ok := c.items[key]; ok {
		return ent.Value.(*entry).value, true
	}
	return nil, false
}

// Remove removes the provided key from the cache, returning if the
// key was contained.
func (c *LRU) Remove(key interface{}) (present bool) {
	if ent, ok := c.items[key]; ok {
		c.removeElement(ent)
		return true
	}
	return false
}

// RemoveOldest removes the oldest item from the cache, returning
// its key and value (or nils if the cache is empty).
func (c *LRU) RemoveOldest() (key interface{}, value interface{}, ok bool) {
	if ent := c.evictList.Back(); ent != nil {
		// Read the entry before removal rather than relying on
		// list.Remove leaving Element.Value intact.
		kv := ent.Value.(*entry)
		c.removeElement(ent)
		return kv.key, kv.value, true
	}
	return nil, nil, false
}

// GetOldest returns the oldest entry without removing it.
func (c *LRU) GetOldest() (key interface{}, value interface{}, ok bool) {
	if ent := c.evictList.Back(); ent != nil {
		kv := ent.Value.(*entry)
		return kv.key, kv.value, true
	}
	return nil, nil, false
}

// Keys returns a slice of the keys in the cache, from oldest to newest.
func (c *LRU) Keys() []interface{} {
	keys := make([]interface{}, len(c.items))
	i := 0
	for ent := c.evictList.Back(); ent != nil; ent = ent.Prev() {
		keys[i] = ent.Value.(*entry).key
		i++
	}
	return keys
}

// Len returns the number of items in the cache.
func (c *LRU) Len() int {
	return c.evictList.Len()
}

// removeOldest removes the oldest item from the cache.
func (c *LRU) removeOldest() {
	if ent := c.evictList.Back(); ent != nil {
		c.removeElement(ent)
	}
}

// removeElement removes a given list element from the cache and
// fires the eviction callback, if any.
func (c *LRU) removeElement(e *list.Element) {
	c.evictList.Remove(e)
	kv := e.Value.(*entry)
	delete(c.items, kv.key)
	if c.onEvict != nil {
		c.onEvict(kv.key, kv.value)
	}
}
--- /dev/null
+package simplelru
+
// LRUCache is the interface for a simple LRU cache.
type LRUCache interface {
	// Adds a value to the cache, returns true if an eviction occurred and
	// updates the "recently used"-ness of the key.
	Add(key, value interface{}) bool

	// Returns key's value from the cache and
	// updates the "recently used"-ness of the key. #value, isFound
	Get(key interface{}) (value interface{}, ok bool)

	// Checks if a key exists in the cache without updating the recent-ness.
	Contains(key interface{}) (ok bool)

	// Returns key's value without updating the "recently used"-ness of the key.
	Peek(key interface{}) (value interface{}, ok bool)

	// Removes a key from the cache, returning whether it was present.
	Remove(key interface{}) bool

	// Removes the oldest entry from the cache. #key, value, isFound
	RemoveOldest() (interface{}, interface{}, bool)

	// Returns the oldest entry from the cache. #key, value, isFound
	GetOldest() (interface{}, interface{}, bool)

	// Returns a slice of the keys in the cache, from oldest to newest.
	Keys() []interface{}

	// Returns the number of items in the cache.
	Len() int

	// Clears all cache entries.
	Purge()
}
--- /dev/null
+package simplelru
+
+import "testing"
+
// TestLRU exercises the core simplelru behavior: capacity eviction,
// Get/Remove semantics, Keys ordering, and Purge.
func TestLRU(t *testing.T) {
	evictCounter := 0
	// Every entry is added with key == value, so an eviction with
	// mismatched key/value would indicate corruption.
	onEvicted := func(k interface{}, v interface{}) {
		if k != v {
			t.Fatalf("Evict values not equal (%v!=%v)", k, v)
		}
		evictCounter++
	}
	l, err := NewLRU(128, onEvicted)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Adding 256 entries to a 128-entry cache must evict the first 128.
	for i := 0; i < 256; i++ {
		l.Add(i, i)
	}
	if l.Len() != 128 {
		t.Fatalf("bad len: %v", l.Len())
	}

	if evictCounter != 128 {
		t.Fatalf("bad evict count: %v", evictCounter)
	}

	// Keys() is oldest-to-newest, so position i should hold key i+128.
	// Get refreshes each key in oldest-first order, which preserves the
	// overall ordering for the checks below.
	for i, k := range l.Keys() {
		if v, ok := l.Get(k); !ok || v != k || v != i+128 {
			t.Fatalf("bad key: %v", k)
		}
	}
	for i := 0; i < 128; i++ {
		_, ok := l.Get(i)
		if ok {
			t.Fatalf("should be evicted")
		}
	}
	for i := 128; i < 256; i++ {
		_, ok := l.Get(i)
		if !ok {
			t.Fatalf("should not be evicted")
		}
	}
	// Remove reports true only the first time; afterwards the key is gone.
	for i := 128; i < 192; i++ {
		ok := l.Remove(i)
		if !ok {
			t.Fatalf("should be contained")
		}
		ok = l.Remove(i)
		if ok {
			t.Fatalf("should not be contained")
		}
		_, ok = l.Get(i)
		if ok {
			t.Fatalf("should be deleted")
		}
	}

	l.Get(192) // expect 192 to be last key in l.Keys()

	// 64 keys remain (192..255); 192 was just refreshed so it must be
	// the newest, with 193..255 in order before it.
	for i, k := range l.Keys() {
		if (i < 63 && k != i+193) || (i == 63 && k != 192) {
			t.Fatalf("out of order key: %v", k)
		}
	}

	// Purge empties the cache entirely.
	l.Purge()
	if l.Len() != 0 {
		t.Fatalf("bad len: %v", l.Len())
	}
	if _, ok := l.Get(200); ok {
		t.Fatalf("should contain nothing")
	}
}
+
+func TestLRU_GetOldest_RemoveOldest(t *testing.T) {
+ l, err := NewLRU(128, nil)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ for i := 0; i < 256; i++ {
+ l.Add(i, i)
+ }
+ k, _, ok := l.GetOldest()
+ if !ok {
+ t.Fatalf("missing")
+ }
+ if k.(int) != 128 {
+ t.Fatalf("bad: %v", k)
+ }
+
+ k, _, ok = l.RemoveOldest()
+ if !ok {
+ t.Fatalf("missing")
+ }
+ if k.(int) != 128 {
+ t.Fatalf("bad: %v", k)
+ }
+
+ k, _, ok = l.RemoveOldest()
+ if !ok {
+ t.Fatalf("missing")
+ }
+ if k.(int) != 129 {
+ t.Fatalf("bad: %v", k)
+ }
+}
+
+// Test that Add returns true/false if an eviction occurred
+func TestLRU_Add(t *testing.T) {
+ evictCounter := 0
+ onEvicted := func(k interface{}, v interface{}) {
+ evictCounter++
+ }
+
+ l, err := NewLRU(1, onEvicted)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ if l.Add(1, 1) == true || evictCounter != 0 {
+ t.Errorf("should not have an eviction")
+ }
+ if l.Add(2, 2) == false || evictCounter != 1 {
+ t.Errorf("should have an eviction")
+ }
+}
+
+// Test that Contains doesn't update recent-ness
+func TestLRU_Contains(t *testing.T) {
+ l, err := NewLRU(2, nil)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ l.Add(1, 1)
+ l.Add(2, 2)
+ if !l.Contains(1) {
+ t.Errorf("1 should be contained")
+ }
+
+ l.Add(3, 3)
+ if l.Contains(1) {
+ t.Errorf("Contains should not have updated recent-ness of 1")
+ }
+}
+
+// Test that Peek doesn't update recent-ness
+func TestLRU_Peek(t *testing.T) {
+ l, err := NewLRU(2, nil)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ l.Add(1, 1)
+ l.Add(2, 2)
+ if v, ok := l.Peek(1); !ok || v != 1 {
+ t.Errorf("1 should be set to 1: %v, %v", v, ok)
+ }
+
+ l.Add(3, 3)
+ if l.Contains(1) {
+ t.Errorf("should not have updated recent-ness of 1")
+ }
+}
--- /dev/null
+# Compiled Object files, Static and Dynamic libs (Shared Objects)
+*.o
+*.a
+*.so
+
+# Folders
+_obj
+_test
+
+# Architecture specific extensions/prefixes
+*.[568vq]
+[568vq].out
+
+*.cgo1.go
+*.cgo2.c
+_cgo_defun.c
+_cgo_gotypes.go
+_cgo_export.*
+
+_testmain.go
+
+*.exe
+*.test
--- /dev/null
+Mozilla Public License, version 2.0
+
+1. Definitions
+
+1.1. "Contributor"
+
+ means each individual or legal entity that creates, contributes to the
+ creation of, or owns Covered Software.
+
+1.2. "Contributor Version"
+
+ means the combination of the Contributions of others (if any) used by a
+ Contributor and that particular Contributor's Contribution.
+
+1.3. "Contribution"
+
+ means Covered Software of a particular Contributor.
+
+1.4. "Covered Software"
+
+ means Source Code Form to which the initial Contributor has attached the
+ notice in Exhibit A, the Executable Form of such Source Code Form, and
+ Modifications of such Source Code Form, in each case including portions
+ thereof.
+
+1.5. "Incompatible With Secondary Licenses"
+ means
+
+ a. that the initial Contributor has attached the notice described in
+ Exhibit B to the Covered Software; or
+
+ b. that the Covered Software was made available under the terms of
+ version 1.1 or earlier of the License, but not also under the terms of
+ a Secondary License.
+
+1.6. "Executable Form"
+
+ means any form of the work other than Source Code Form.
+
+1.7. "Larger Work"
+
+ means a work that combines Covered Software with other material, in a
+ separate file or files, that is not Covered Software.
+
+1.8. "License"
+
+ means this document.
+
+1.9. "Licensable"
+
+ means having the right to grant, to the maximum extent possible, whether
+ at the time of the initial grant or subsequently, any and all of the
+ rights conveyed by this License.
+
+1.10. "Modifications"
+
+ means any of the following:
+
+ a. any file in Source Code Form that results from an addition to,
+ deletion from, or modification of the contents of Covered Software; or
+
+ b. any new file in Source Code Form that contains any Covered Software.
+
+1.11. "Patent Claims" of a Contributor
+
+ means any patent claim(s), including without limitation, method,
+ process, and apparatus claims, in any patent Licensable by such
+ Contributor that would be infringed, but for the grant of the License,
+ by the making, using, selling, offering for sale, having made, import,
+ or transfer of either its Contributions or its Contributor Version.
+
+1.12. "Secondary License"
+
+ means either the GNU General Public License, Version 2.0, the GNU Lesser
+ General Public License, Version 2.1, the GNU Affero General Public
+ License, Version 3.0, or any later versions of those licenses.
+
+1.13. "Source Code Form"
+
+ means the form of the work preferred for making modifications.
+
+1.14. "You" (or "Your")
+
+ means an individual or a legal entity exercising rights under this
+ License. For legal entities, "You" includes any entity that controls, is
+ controlled by, or is under common control with You. For purposes of this
+ definition, "control" means (a) the power, direct or indirect, to cause
+ the direction or management of such entity, whether by contract or
+ otherwise, or (b) ownership of more than fifty percent (50%) of the
+ outstanding shares or beneficial ownership of such entity.
+
+
+2. License Grants and Conditions
+
+2.1. Grants
+
+ Each Contributor hereby grants You a world-wide, royalty-free,
+ non-exclusive license:
+
+ a. under intellectual property rights (other than patent or trademark)
+ Licensable by such Contributor to use, reproduce, make available,
+ modify, display, perform, distribute, and otherwise exploit its
+ Contributions, either on an unmodified basis, with Modifications, or
+ as part of a Larger Work; and
+
+ b. under Patent Claims of such Contributor to make, use, sell, offer for
+ sale, have made, import, and otherwise transfer either its
+ Contributions or its Contributor Version.
+
+2.2. Effective Date
+
+ The licenses granted in Section 2.1 with respect to any Contribution
+ become effective for each Contribution on the date the Contributor first
+ distributes such Contribution.
+
+2.3. Limitations on Grant Scope
+
+ The licenses granted in this Section 2 are the only rights granted under
+ this License. No additional rights or licenses will be implied from the
+ distribution or licensing of Covered Software under this License.
+ Notwithstanding Section 2.1(b) above, no patent license is granted by a
+ Contributor:
+
+ a. for any code that a Contributor has removed from Covered Software; or
+
+ b. for infringements caused by: (i) Your and any other third party's
+ modifications of Covered Software, or (ii) the combination of its
+ Contributions with other software (except as part of its Contributor
+ Version); or
+
+ c. under Patent Claims infringed by Covered Software in the absence of
+ its Contributions.
+
+ This License does not grant any rights in the trademarks, service marks,
+ or logos of any Contributor (except as may be necessary to comply with
+ the notice requirements in Section 3.4).
+
+2.4. Subsequent Licenses
+
+ No Contributor makes additional grants as a result of Your choice to
+ distribute the Covered Software under a subsequent version of this
+ License (see Section 10.2) or under the terms of a Secondary License (if
+ permitted under the terms of Section 3.3).
+
+2.5. Representation
+
+ Each Contributor represents that the Contributor believes its
+ Contributions are its original creation(s) or it has sufficient rights to
+ grant the rights to its Contributions conveyed by this License.
+
+2.6. Fair Use
+
+ This License is not intended to limit any rights You have under
+ applicable copyright doctrines of fair use, fair dealing, or other
+ equivalents.
+
+2.7. Conditions
+
+ Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in
+ Section 2.1.
+
+
+3. Responsibilities
+
+3.1. Distribution of Source Form
+
+ All distribution of Covered Software in Source Code Form, including any
+ Modifications that You create or to which You contribute, must be under
+ the terms of this License. You must inform recipients that the Source
+ Code Form of the Covered Software is governed by the terms of this
+ License, and how they can obtain a copy of this License. You may not
+ attempt to alter or restrict the recipients' rights in the Source Code
+ Form.
+
+3.2. Distribution of Executable Form
+
+ If You distribute Covered Software in Executable Form then:
+
+ a. such Covered Software must also be made available in Source Code Form,
+ as described in Section 3.1, and You must inform recipients of the
+ Executable Form how they can obtain a copy of such Source Code Form by
+ reasonable means in a timely manner, at a charge no more than the cost
+ of distribution to the recipient; and
+
+ b. You may distribute such Executable Form under the terms of this
+ License, or sublicense it under different terms, provided that the
+ license for the Executable Form does not attempt to limit or alter the
+ recipients' rights in the Source Code Form under this License.
+
+3.3. Distribution of a Larger Work
+
+ You may create and distribute a Larger Work under terms of Your choice,
+ provided that You also comply with the requirements of this License for
+ the Covered Software. If the Larger Work is a combination of Covered
+ Software with a work governed by one or more Secondary Licenses, and the
+ Covered Software is not Incompatible With Secondary Licenses, this
+ License permits You to additionally distribute such Covered Software
+ under the terms of such Secondary License(s), so that the recipient of
+ the Larger Work may, at their option, further distribute the Covered
+ Software under the terms of either this License or such Secondary
+ License(s).
+
+3.4. Notices
+
+ You may not remove or alter the substance of any license notices
+ (including copyright notices, patent notices, disclaimers of warranty, or
+ limitations of liability) contained within the Source Code Form of the
+ Covered Software, except that You may alter any license notices to the
+ extent required to remedy known factual inaccuracies.
+
+3.5. Application of Additional Terms
+
+ You may choose to offer, and to charge a fee for, warranty, support,
+ indemnity or liability obligations to one or more recipients of Covered
+ Software. However, You may do so only on Your own behalf, and not on
+ behalf of any Contributor. You must make it absolutely clear that any
+ such warranty, support, indemnity, or liability obligation is offered by
+ You alone, and You hereby agree to indemnify every Contributor for any
+ liability incurred by such Contributor as a result of warranty, support,
+ indemnity or liability terms You offer. You may include additional
+ disclaimers of warranty and limitations of liability specific to any
+ jurisdiction.
+
+4. Inability to Comply Due to Statute or Regulation
+
+ If it is impossible for You to comply with any of the terms of this License
+ with respect to some or all of the Covered Software due to statute,
+ judicial order, or regulation then You must: (a) comply with the terms of
+ this License to the maximum extent possible; and (b) describe the
+ limitations and the code they affect. Such description must be placed in a
+ text file included with all distributions of the Covered Software under
+ this License. Except to the extent prohibited by statute or regulation,
+ such description must be sufficiently detailed for a recipient of ordinary
+ skill to be able to understand it.
+
+5. Termination
+
+5.1. The rights granted under this License will terminate automatically if You
+ fail to comply with any of its terms. However, if You become compliant,
+ then the rights granted under this License from a particular Contributor
+ are reinstated (a) provisionally, unless and until such Contributor
+ explicitly and finally terminates Your grants, and (b) on an ongoing
+ basis, if such Contributor fails to notify You of the non-compliance by
+ some reasonable means prior to 60 days after You have come back into
+ compliance. Moreover, Your grants from a particular Contributor are
+ reinstated on an ongoing basis if such Contributor notifies You of the
+ non-compliance by some reasonable means, this is the first time You have
+ received notice of non-compliance with this License from such
+ Contributor, and You become compliant prior to 30 days after Your receipt
+ of the notice.
+
+5.2. If You initiate litigation against any entity by asserting a patent
+ infringement claim (excluding declaratory judgment actions,
+ counter-claims, and cross-claims) alleging that a Contributor Version
+ directly or indirectly infringes any patent, then the rights granted to
+ You by any and all Contributors for the Covered Software under Section
+ 2.1 of this License shall terminate.
+
+5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user
+ license agreements (excluding distributors and resellers) which have been
+ validly granted by You or Your distributors under this License prior to
+ termination shall survive termination.
+
+6. Disclaimer of Warranty
+
+ Covered Software is provided under this License on an "as is" basis,
+ without warranty of any kind, either expressed, implied, or statutory,
+ including, without limitation, warranties that the Covered Software is free
+ of defects, merchantable, fit for a particular purpose or non-infringing.
+ The entire risk as to the quality and performance of the Covered Software
+ is with You. Should any Covered Software prove defective in any respect,
+ You (not any Contributor) assume the cost of any necessary servicing,
+ repair, or correction. This disclaimer of warranty constitutes an essential
+ part of this License. No use of any Covered Software is authorized under
+ this License except under this disclaimer.
+
+7. Limitation of Liability
+
+ Under no circumstances and under no legal theory, whether tort (including
+ negligence), contract, or otherwise, shall any Contributor, or anyone who
+ distributes Covered Software as permitted above, be liable to You for any
+ direct, indirect, special, incidental, or consequential damages of any
+ character including, without limitation, damages for lost profits, loss of
+ goodwill, work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses, even if such party shall have been
+ informed of the possibility of such damages. This limitation of liability
+ shall not apply to liability for death or personal injury resulting from
+ such party's negligence to the extent applicable law prohibits such
+ limitation. Some jurisdictions do not allow the exclusion or limitation of
+ incidental or consequential damages, so this exclusion and limitation may
+ not apply to You.
+
+8. Litigation
+
+ Any litigation relating to this License may be brought only in the courts
+ of a jurisdiction where the defendant maintains its principal place of
+ business and such litigation shall be governed by laws of that
+ jurisdiction, without reference to its conflict-of-law provisions. Nothing
+ in this Section shall prevent a party's ability to bring cross-claims or
+ counter-claims.
+
+9. Miscellaneous
+
+ This License represents the complete agreement concerning the subject
+ matter hereof. If any provision of this License is held to be
+ unenforceable, such provision shall be reformed only to the extent
+ necessary to make it enforceable. Any law or regulation which provides that
+ the language of a contract shall be construed against the drafter shall not
+ be used to construe this License against a Contributor.
+
+
+10. Versions of the License
+
+10.1. New Versions
+
+ Mozilla Foundation is the license steward. Except as provided in Section
+ 10.3, no one other than the license steward has the right to modify or
+ publish new versions of this License. Each version will be given a
+ distinguishing version number.
+
+10.2. Effect of New Versions
+
+ You may distribute the Covered Software under the terms of the version
+ of the License under which You originally received the Covered Software,
+ or under the terms of any subsequent version published by the license
+ steward.
+
+10.3. Modified Versions
+
+ If you create software not governed by this License, and you want to
+ create a new license for such software, you may create and use a
+ modified version of this License if you rename the license and remove
+ any references to the name of the license steward (except to note that
+ such modified license differs from this License).
+
+10.4. Distributing Source Code Form that is Incompatible With Secondary
+ Licenses If You choose to distribute Source Code Form that is
+ Incompatible With Secondary Licenses under the terms of this version of
+ the License, the notice described in Exhibit B of this License must be
+ attached.
+
+Exhibit A - Source Code Form License Notice
+
+ This Source Code Form is subject to the
+ terms of the Mozilla Public License, v.
+ 2.0. If a copy of the MPL was not
+ distributed with this file, You can
+ obtain one at
+ http://mozilla.org/MPL/2.0/.
+
+If it is not possible or desirable to put the notice in a particular file,
+then You may include the notice in a location (such as a LICENSE file in a
+relevant directory) where a recipient would be likely to look for such a
+notice.
+
+You may add additional accurate notices of copyright ownership.
+
+Exhibit B - "Incompatible With Secondary Licenses" Notice
+
+ This Source Code Form is "Incompatible
+ With Secondary Licenses", as defined by
+ the Mozilla Public License, v. 2.0.
\ No newline at end of file
--- /dev/null
+# Yamux
+
+Yamux (Yet another Multiplexer) is a multiplexing library for Golang.
+It relies on an underlying connection to provide reliability
+and ordering, such as TCP or Unix domain sockets, and provides
+stream-oriented multiplexing. It is inspired by SPDY but is not
+interoperable with it.
+
+Yamux features include:
+
+* Bi-directional streams
+ * Streams can be opened by either client or server
+ * Useful for NAT traversal
+ * Server-side push support
+* Flow control
+ * Avoid starvation
+ * Back-pressure to prevent overwhelming a receiver
+* Keep Alives
+ * Enables persistent connections over a load balancer
+* Efficient
+ * Enables thousands of logical streams with low overhead
+
+## Documentation
+
+For complete documentation, see the associated [Godoc](http://godoc.org/github.com/hashicorp/yamux).
+
+## Specification
+
+The full specification for Yamux is provided in the `spec.md` file.
+It can be used as a guide to implementors of interoperable libraries.
+
+## Usage
+
+Using Yamux is remarkably simple:
+
+```go
+
+func client() {
+ // Get a TCP connection
+ conn, err := net.Dial(...)
+ if err != nil {
+ panic(err)
+ }
+
+ // Setup client side of yamux
+ session, err := yamux.Client(conn, nil)
+ if err != nil {
+ panic(err)
+ }
+
+ // Open a new stream
+ stream, err := session.Open()
+ if err != nil {
+ panic(err)
+ }
+
+ // Stream implements net.Conn
+ stream.Write([]byte("ping"))
+}
+
+func server() {
+ // Accept a TCP connection
+ conn, err := listener.Accept()
+ if err != nil {
+ panic(err)
+ }
+
+ // Setup server side of yamux
+ session, err := yamux.Server(conn, nil)
+ if err != nil {
+ panic(err)
+ }
+
+ // Accept a stream
+ stream, err := session.Accept()
+ if err != nil {
+ panic(err)
+ }
+
+ // Listen for a message
+ buf := make([]byte, 4)
+ stream.Read(buf)
+}
+
+```
+
--- /dev/null
+package yamux
+
+import (
+ "fmt"
+ "net"
+)
+
// hasAddr is used to get the address from the underlying connection;
// any net.Conn implementation satisfies it.
type hasAddr interface {
	LocalAddr() net.Addr
	RemoteAddr() net.Addr
}
+
// yamuxAddr is used when we cannot get the underlying address
// from the transport; Addr carries a placeholder label.
type yamuxAddr struct {
	Addr string
}

// Network reports the pseudo-network name for yamux addresses.
func (a *yamuxAddr) Network() string {
	return "yamux"
}

// String renders the address in the form "yamux:<Addr>".
func (a *yamuxAddr) String() string {
	return fmt.Sprintf("yamux:%s", a.Addr)
}
+
// Addr is used to get the address of the listener.
// It delegates to LocalAddr, so a Session can stand in where a
// net.Listener-style Addr method is expected.
func (s *Session) Addr() net.Addr {
	return s.LocalAddr()
}
+
+// LocalAddr is used to get the local address of the
+// underlying connection.
+func (s *Session) LocalAddr() net.Addr {
+ addr, ok := s.conn.(hasAddr)
+ if !ok {
+ return &yamuxAddr{"local"}
+ }
+ return addr.LocalAddr()
+}
+
+// RemoteAddr is used to get the address of remote end
+// of the underlying connection
+func (s *Session) RemoteAddr() net.Addr {
+ addr, ok := s.conn.(hasAddr)
+ if !ok {
+ return &yamuxAddr{"remote"}
+ }
+ return addr.RemoteAddr()
+}
+
// LocalAddr returns the local address of the stream's
// parent session connection.
func (s *Stream) LocalAddr() net.Addr {
	return s.session.LocalAddr()
}
+
// RemoteAddr returns the remote address of the stream's
// parent session connection.
// (Comment previously misnamed this as LocalAddr.)
func (s *Stream) RemoteAddr() net.Addr {
	return s.session.RemoteAddr()
}
--- /dev/null
+package yamux
+
+import (
+ "testing"
+)
+
// BenchmarkPing measures the round-trip time of a yamux ping over a
// client/server session pair from the test harness.
func BenchmarkPing(b *testing.B) {
	client, server := testClientServer()
	defer client.Close()
	defer server.Close()

	for i := 0; i < b.N; i++ {
		rtt, err := client.Ping()
		if err != nil {
			b.Fatalf("err: %v", err)
		}
		// A zero RTT would indicate the ping never actually round-tripped.
		if rtt == 0 {
			b.Fatalf("bad: %v", rtt)
		}
	}
}
+
// BenchmarkAccept measures stream open/accept throughput: the client
// opens b.N streams while a goroutine accepts and closes them.
func BenchmarkAccept(b *testing.B) {
	client, server := testClientServer()
	defer client.Close()
	defer server.Close()

	go func() {
		for i := 0; i < b.N; i++ {
			stream, err := server.AcceptStream()
			if err != nil {
				// Session closed; stop accepting.
				return
			}
			stream.Close()
		}
	}()

	for i := 0; i < b.N; i++ {
		stream, err := client.Open()
		if err != nil {
			b.Fatalf("err: %v", err)
		}
		stream.Close()
	}
}
+
+func BenchmarkSendRecv(b *testing.B) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ sendBuf := make([]byte, 512)
+ recvBuf := make([]byte, 512)
+
+ doneCh := make(chan struct{})
+ go func() {
+ stream, err := server.AcceptStream()
+ if err != nil {
+ return
+ }
+ defer stream.Close()
+ for i := 0; i < b.N; i++ {
+ if _, err := stream.Read(recvBuf); err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ }
+ close(doneCh)
+ }()
+
+ stream, err := client.Open()
+ if err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+ for i := 0; i < b.N; i++ {
+ if _, err := stream.Write(sendBuf); err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ }
+ <-doneCh
+}
+
+func BenchmarkSendRecvLarge(b *testing.B) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+ const sendSize = 512 * 1024 * 1024
+ const recvSize = 4 * 1024
+
+ sendBuf := make([]byte, sendSize)
+ recvBuf := make([]byte, recvSize)
+
+ b.ResetTimer()
+ recvDone := make(chan struct{})
+
+ go func() {
+ stream, err := server.AcceptStream()
+ if err != nil {
+ return
+ }
+ defer stream.Close()
+ for i := 0; i < b.N; i++ {
+ for j := 0; j < sendSize/recvSize; j++ {
+ if _, err := stream.Read(recvBuf); err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ }
+ }
+ close(recvDone)
+ }()
+
+ stream, err := client.Open()
+ if err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+ for i := 0; i < b.N; i++ {
+ if _, err := stream.Write(sendBuf); err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ }
+ <-recvDone
+}
--- /dev/null
+package yamux
+
+import (
+ "encoding/binary"
+ "fmt"
+)
+
var (
	// ErrInvalidVersion means we received a frame with an
	// invalid version
	ErrInvalidVersion = fmt.Errorf("invalid protocol version")

	// ErrInvalidMsgType means we received a frame with an
	// invalid message type
	ErrInvalidMsgType = fmt.Errorf("invalid msg type")

	// ErrSessionShutdown is used if there is a shutdown during
	// an operation
	ErrSessionShutdown = fmt.Errorf("session shutdown")

	// ErrStreamsExhausted is returned if we have no more
	// stream ids to issue
	ErrStreamsExhausted = fmt.Errorf("streams exhausted")

	// ErrDuplicateStream is used if a duplicate stream is
	// opened inbound
	ErrDuplicateStream = fmt.Errorf("duplicate stream initiated")

	// ErrRecvWindowExceeded indicates the receive window was exceeded
	// (comment previously misnamed this ErrReceiveWindowExceeded)
	ErrRecvWindowExceeded = fmt.Errorf("recv window exceeded")

	// ErrTimeout is used when we reach an IO deadline
	ErrTimeout = fmt.Errorf("i/o deadline reached")

	// ErrStreamClosed is returned when using a closed stream
	ErrStreamClosed = fmt.Errorf("stream closed")

	// ErrUnexpectedFlag is set when we get an unexpected flag
	ErrUnexpectedFlag = fmt.Errorf("unexpected flag")

	// ErrRemoteGoAway is used when we get a go away from the other side
	ErrRemoteGoAway = fmt.Errorf("remote end is not accepting connections")

	// ErrConnectionReset is sent if a stream is reset. This can happen
	// if the backlog is exceeded, or if there was a remote GoAway.
	ErrConnectionReset = fmt.Errorf("connection reset")

	// ErrConnectionWriteTimeout indicates that we hit the "safety valve"
	// timeout writing to the underlying stream connection.
	ErrConnectionWriteTimeout = fmt.Errorf("connection write timeout")

	// ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close
	ErrKeepAliveTimeout = fmt.Errorf("keepalive timeout")
)
+
+const (
+	// protoVersion is the only version we support
+	protoVersion uint8 = 0
+)
+
+// Frame types. The value of each constant is also its index into the
+// handlers dispatch table in session.go, so the two must stay in sync.
+const (
+	// Data is used for data frames. They are followed
+	// by length bytes worth of payload.
+	typeData uint8 = iota
+
+	// WindowUpdate is used to change the window of
+	// a given stream. The length indicates the delta
+	// update to the window.
+	typeWindowUpdate
+
+	// Ping is sent as a keep-alive or to measure
+	// the RTT. The StreamID and Length value are echoed
+	// back in the response.
+	typePing
+
+	// GoAway is sent to terminate a session. The StreamID
+	// should be 0 and the length is an error code.
+	typeGoAway
+)
+
+// Header flags. These are single bits, so multiple flags may be OR'd
+// together in one frame.
+const (
+	// SYN is sent to signal a new stream. May
+	// be sent with a data payload
+	flagSYN uint16 = 1 << iota
+
+	// ACK is sent to acknowledge a new stream. May
+	// be sent with a data payload
+	flagACK
+
+	// FIN is sent to half-close the given stream.
+	// May be sent with a data payload.
+	flagFIN
+
+	// RST is used to hard close a given stream.
+	flagRST
+)
+
+const (
+	// initialStreamWindow is the initial stream window size (256 KiB)
+	initialStreamWindow uint32 = 256 * 1024
+)
+
+// GoAway error codes, carried in the Length field of a typeGoAway frame.
+const (
+	// goAwayNormal is sent on a normal termination
+	goAwayNormal uint32 = iota
+
+	// goAwayProtoErr sent on a protocol error
+	goAwayProtoErr
+
+	// goAwayInternalErr sent on an internal error
+	goAwayInternalErr
+)
+
+// Sizes (in bytes) of each header field; headerSize is the full 12-byte
+// frame header length.
+const (
+	sizeOfVersion  = 1
+	sizeOfType     = 1
+	sizeOfFlags    = 2
+	sizeOfStreamID = 4
+	sizeOfLength   = 4
+	headerSize     = sizeOfVersion + sizeOfType + sizeOfFlags +
+		sizeOfStreamID + sizeOfLength
+)
+
+// header is a raw yamux frame header: at least headerSize (12) bytes,
+// laid out big-endian as version(1) | type(1) | flags(2) | streamID(4) | length(4).
+// All accessors assume len(h) >= headerSize.
+type header []byte
+
+// Version returns the protocol version field.
+func (h header) Version() uint8 {
+	return h[0]
+}
+
+// MsgType returns the frame type (typeData, typeWindowUpdate, ...).
+func (h header) MsgType() uint8 {
+	return h[1]
+}
+
+// Flags returns the OR'd flag bits (flagSYN, flagACK, ...).
+func (h header) Flags() uint16 {
+	return binary.BigEndian.Uint16(h[2:4])
+}
+
+// StreamID returns the stream identifier the frame refers to.
+func (h header) StreamID() uint32 {
+	return binary.BigEndian.Uint32(h[4:8])
+}
+
+// Length returns the type-dependent length field (payload size for data
+// frames, window delta for window updates, ping ID or error code otherwise).
+func (h header) Length() uint32 {
+	return binary.BigEndian.Uint32(h[8:12])
+}
+
+// String renders the decoded header fields for logging/debugging.
+func (h header) String() string {
+	return fmt.Sprintf("Vsn:%d Type:%d Flags:%d StreamID:%d Length:%d",
+		h.Version(), h.MsgType(), h.Flags(), h.StreamID(), h.Length())
+}
+
+// encode writes a complete header in place, always using protoVersion.
+func (h header) encode(msgType uint8, flags uint16, streamID uint32, length uint32) {
+	h[0] = protoVersion
+	h[1] = msgType
+	binary.BigEndian.PutUint16(h[2:4], flags)
+	binary.BigEndian.PutUint32(h[4:8], streamID)
+	binary.BigEndian.PutUint32(h[8:12], length)
+}
--- /dev/null
+package yamux
+
+import (
+ "testing"
+)
+
+// TestConst pins the wire-protocol constants: frame type values, flag
+// bit positions, go-away codes and the 12-byte header size. Any change
+// here is a wire-format break.
+func TestConst(t *testing.T) {
+	if protoVersion != 0 {
+		t.Fatalf("bad: %v", protoVersion)
+	}
+
+	if typeData != 0 {
+		t.Fatalf("bad: %v", typeData)
+	}
+	if typeWindowUpdate != 1 {
+		t.Fatalf("bad: %v", typeWindowUpdate)
+	}
+	if typePing != 2 {
+		t.Fatalf("bad: %v", typePing)
+	}
+	if typeGoAway != 3 {
+		t.Fatalf("bad: %v", typeGoAway)
+	}
+
+	if flagSYN != 1 {
+		t.Fatalf("bad: %v", flagSYN)
+	}
+	if flagACK != 2 {
+		t.Fatalf("bad: %v", flagACK)
+	}
+	if flagFIN != 4 {
+		t.Fatalf("bad: %v", flagFIN)
+	}
+	if flagRST != 8 {
+		t.Fatalf("bad: %v", flagRST)
+	}
+
+	if goAwayNormal != 0 {
+		t.Fatalf("bad: %v", goAwayNormal)
+	}
+	if goAwayProtoErr != 1 {
+		t.Fatalf("bad: %v", goAwayProtoErr)
+	}
+	if goAwayInternalErr != 2 {
+		t.Fatalf("bad: %v", goAwayInternalErr)
+	}
+
+	if headerSize != 12 {
+		t.Fatalf("bad header size")
+	}
+}
+
+// TestEncodeDecode round-trips one header through encode and the field
+// accessors, verifying every field comes back unchanged.
+func TestEncodeDecode(t *testing.T) {
+	hdr := header(make([]byte, headerSize))
+	hdr.encode(typeWindowUpdate, flagACK|flagRST, 1234, 4321)
+
+	// Each entry is one decoded-field check; stop at the first mismatch.
+	fieldChecks := []bool{
+		hdr.Version() == protoVersion,
+		hdr.MsgType() == typeWindowUpdate,
+		hdr.Flags() == flagACK|flagRST,
+		hdr.StreamID() == 1234,
+		hdr.Length() == 4321,
+	}
+	for _, ok := range fieldChecks {
+		if !ok {
+			t.Fatalf("bad: %v", hdr)
+		}
+	}
+}
--- /dev/null
+package yamux
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "time"
+)
+
+// Config is used to tune the Yamux session
+type Config struct {
+	// AcceptBacklog is used to limit how many streams may be
+	// waiting an accept.
+	AcceptBacklog int
+
+	// EnableKeepAlive is used to do periodic keep alive
+	// messages using a ping.
+	EnableKeepAlive bool
+
+	// KeepAliveInterval is how often to perform the keep alive
+	KeepAliveInterval time.Duration
+
+	// ConnectionWriteTimeout is meant to be a "safety valve" timeout after
+	// which we will suspect a problem with the underlying connection and
+	// close it. This is only applied to writes, where there's generally
+	// an expectation that things will move along quickly.
+	ConnectionWriteTimeout time.Duration
+
+	// MaxStreamWindowSize is used to control the maximum
+	// window size that we allow for a stream.
+	MaxStreamWindowSize uint32
+
+	// LogOutput is used to control the log destination
+	LogOutput io.Writer
+}
+
+// DefaultConfig is used to return a default configuration
+func DefaultConfig() *Config {
+	config := new(Config)
+	config.AcceptBacklog = 256
+	config.EnableKeepAlive = true
+	config.KeepAliveInterval = 30 * time.Second
+	config.ConnectionWriteTimeout = 10 * time.Second
+	config.MaxStreamWindowSize = initialStreamWindow
+	config.LogOutput = os.Stderr
+	return config
+}
+
+// VerifyConfig is used to verify the sanity of configuration.
+// It returns a descriptive error for the first invalid field found.
+func VerifyConfig(config *Config) error {
+	if config.AcceptBacklog <= 0 {
+		return fmt.Errorf("backlog must be positive")
+	}
+	// A negative interval would make the keepalive timer misbehave just
+	// like a zero one, so reject both (the message already says "positive").
+	if config.KeepAliveInterval <= 0 {
+		return fmt.Errorf("keep-alive interval must be positive")
+	}
+	if config.MaxStreamWindowSize < initialStreamWindow {
+		return fmt.Errorf("MaxStreamWindowSize must be larger than %d", initialStreamWindow)
+	}
+	return nil
+}
+
+// Server is used to initialize a new server-side connection.
+// There must be at most one server-side connection. If a nil config is
+// provided, the DefaultConfiguration will be used. The config is
+// validated with VerifyConfig before the session is created.
+func Server(conn io.ReadWriteCloser, config *Config) (*Session, error) {
+	if config == nil {
+		config = DefaultConfig()
+	}
+	if err := VerifyConfig(config); err != nil {
+		return nil, err
+	}
+	return newSession(config, conn, false), nil
+}
+
+// Client is used to initialize a new client-side connection.
+// There must be at most one client-side connection. If a nil config is
+// provided, the DefaultConfiguration will be used.
+func Client(conn io.ReadWriteCloser, config *Config) (*Session, error) {
+	if config == nil {
+		config = DefaultConfig()
+	}
+
+	if err := VerifyConfig(config); err != nil {
+		return nil, err
+	}
+	return newSession(config, conn, true), nil
+}
--- /dev/null
+package yamux
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "math"
+ "net"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+// Session is used to wrap a reliable ordered connection and to
+// multiplex it into multiple streams.
+type Session struct {
+	// remoteGoAway indicates the remote side does
+	// not want further connections. Must be first for alignment
+	// (accessed via sync/atomic, which requires 64-bit-safe placement
+	// on 32-bit platforms).
+	remoteGoAway int32
+
+	// localGoAway indicates that we should stop
+	// accepting further connections. Must be first for alignment.
+	localGoAway int32
+
+	// nextStreamID is the next stream we should
+	// send. This depends if we are a client/server.
+	nextStreamID uint32
+
+	// config holds our configuration
+	config *Config
+
+	// logger is used for our logs
+	logger *log.Logger
+
+	// conn is the underlying connection
+	conn io.ReadWriteCloser
+
+	// bufRead is a buffered reader
+	bufRead *bufio.Reader
+
+	// pings is used to track inflight pings
+	pings    map[uint32]chan struct{}
+	pingID   uint32
+	pingLock sync.Mutex
+
+	// streams maps a stream id to a stream, and inflight has an entry
+	// for any outgoing stream that has not yet been established. Both are
+	// protected by streamLock.
+	streams    map[uint32]*Stream
+	inflight   map[uint32]struct{}
+	streamLock sync.Mutex
+
+	// synCh acts like a semaphore. It is sized to the AcceptBacklog which
+	// is assumed to be symmetric between the client and server. This allows
+	// the client to avoid exceeding the backlog and instead blocks the open.
+	synCh chan struct{}
+
+	// acceptCh is used to pass ready streams to the client
+	acceptCh chan *Stream
+
+	// sendCh is used to mark a stream as ready to send,
+	// or to send a header out directly.
+	sendCh chan sendReady
+
+	// recvDoneCh is closed when recv() exits to avoid a race
+	// between stream registration and stream shutdown
+	recvDoneCh chan struct{}
+
+	// shutdown is used to safely close a session
+	shutdown     bool
+	shutdownErr  error
+	shutdownCh   chan struct{}
+	shutdownLock sync.Mutex
+}
+
+// sendReady is used to either mark a stream as ready
+// or to directly send a header. Err receives the outcome of the write
+// (nil on success); it must be buffered so the send loop never blocks.
+type sendReady struct {
+	Hdr  []byte
+	Body io.Reader
+	Err  chan error
+}
+
+// newSession is used to construct a new session and start its
+// background goroutines (recv, send, and optionally keepalive).
+// The config must already be validated.
+func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
+	s := &Session{
+		config:     config,
+		logger:     log.New(config.LogOutput, "", log.LstdFlags),
+		conn:       conn,
+		bufRead:    bufio.NewReader(conn),
+		pings:      make(map[uint32]chan struct{}),
+		streams:    make(map[uint32]*Stream),
+		inflight:   make(map[uint32]struct{}),
+		synCh:      make(chan struct{}, config.AcceptBacklog),
+		acceptCh:   make(chan *Stream, config.AcceptBacklog),
+		sendCh:     make(chan sendReady, 64),
+		recvDoneCh: make(chan struct{}),
+		shutdownCh: make(chan struct{}),
+	}
+	// Clients allocate odd stream IDs, servers even ones, so the two
+	// sides can never collide.
+	if client {
+		s.nextStreamID = 1
+	} else {
+		s.nextStreamID = 2
+	}
+	go s.recv()
+	go s.send()
+	if config.EnableKeepAlive {
+		go s.keepalive()
+	}
+	return s
+}
+
+// IsClosed does a safe check to see if we have shutdown.
+// It is a non-blocking probe of shutdownCh, which is closed exactly
+// once in Close().
+func (s *Session) IsClosed() bool {
+	select {
+	case <-s.shutdownCh:
+		return true
+	default:
+		return false
+	}
+}
+
+// CloseChan returns a read-only channel which is closed as
+// soon as the session is closed.
+func (s *Session) CloseChan() <-chan struct{} {
+	return s.shutdownCh
+}
+
+// NumStreams returns the number of currently open streams
+func (s *Session) NumStreams() int {
+	s.streamLock.Lock()
+	defer s.streamLock.Unlock()
+	return len(s.streams)
+}
+
+// Open is used to create a new stream as a net.Conn.
+// The explicit nil return on error avoids handing callers a non-nil
+// interface that wraps a nil *Stream.
+func (s *Session) Open() (net.Conn, error) {
+	conn, err := s.OpenStream()
+	if err != nil {
+		return nil, err
+	}
+	return conn, nil
+}
+
+// OpenStream is used to create a new stream.
+// It blocks when AcceptBacklog SYNs are already in flight, fails fast
+// on shutdown or a remote GoAway, and sends the initial window update
+// (carrying the SYN flag) to open the stream on the wire.
+func (s *Session) OpenStream() (*Stream, error) {
+	if s.IsClosed() {
+		return nil, ErrSessionShutdown
+	}
+	if atomic.LoadInt32(&s.remoteGoAway) == 1 {
+		return nil, ErrRemoteGoAway
+	}
+
+	// Block if we have too many inflight SYNs
+	select {
+	case s.synCh <- struct{}{}:
+	case <-s.shutdownCh:
+		return nil, ErrSessionShutdown
+	}
+
+GET_ID:
+	// Get an ID, and check for stream exhaustion. The CAS loop makes the
+	// id += 2 step safe against concurrent opens without holding a lock.
+	id := atomic.LoadUint32(&s.nextStreamID)
+	if id >= math.MaxUint32-1 {
+		return nil, ErrStreamsExhausted
+	}
+	if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) {
+		goto GET_ID
+	}
+
+	// Register the stream
+	stream := newStream(s, id, streamInit)
+	s.streamLock.Lock()
+	s.streams[id] = stream
+	s.inflight[id] = struct{}{}
+	s.streamLock.Unlock()
+
+	// Send the window update to create the stream. On failure, give the
+	// SYN semaphore slot back so future opens are not starved.
+	if err := stream.sendWindowUpdate(); err != nil {
+		select {
+		case <-s.synCh:
+		default:
+			s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore")
+		}
+		return nil, err
+	}
+	return stream, nil
+}
+
+// Accept is used to block until the next available stream
+// is ready to be accepted.
+// The explicit nil/conn split keeps a nil *Stream from leaking into a
+// non-nil net.Conn interface, mirroring Open.
+func (s *Session) Accept() (net.Conn, error) {
+	conn, err := s.AcceptStream()
+	if err != nil {
+		return nil, err
+	}
+	// err is known nil here; return the literal for clarity/consistency.
+	return conn, nil
+}
+
+// AcceptStream is used to block until the next available stream
+// is ready to be accepted. Accepting sends the window update that
+// carries our ACK back to the opener.
+func (s *Session) AcceptStream() (*Stream, error) {
+	select {
+	case stream := <-s.acceptCh:
+		if err := stream.sendWindowUpdate(); err != nil {
+			return nil, err
+		}
+		return stream, nil
+	case <-s.shutdownCh:
+		// Report the error that caused the shutdown, not a generic one.
+		return nil, s.shutdownErr
+	}
+}
+
+// Close is used to close the session and all streams.
+// Attempts to send a GoAway before closing the connection.
+// Safe to call multiple times; only the first call has any effect.
+func (s *Session) Close() error {
+	s.shutdownLock.Lock()
+	defer s.shutdownLock.Unlock()
+
+	if s.shutdown {
+		return nil
+	}
+	s.shutdown = true
+	if s.shutdownErr == nil {
+		s.shutdownErr = ErrSessionShutdown
+	}
+	close(s.shutdownCh)
+	s.conn.Close()
+	// Wait for the recv loop to exit so no goroutine is still touching
+	// the streams map while we force-close everything below.
+	<-s.recvDoneCh
+
+	s.streamLock.Lock()
+	defer s.streamLock.Unlock()
+	for _, stream := range s.streams {
+		stream.forceClose()
+	}
+	return nil
+}
+
+// exitErr is used to handle an error that is causing the
+// session to terminate. It records the first error (later ones are
+// ignored) and then closes the session.
+func (s *Session) exitErr(err error) {
+	s.shutdownLock.Lock()
+	if s.shutdownErr == nil {
+		s.shutdownErr = err
+	}
+	s.shutdownLock.Unlock()
+	s.Close()
+}
+
+// GoAway can be used to prevent accepting further
+// connections. It does not close the underlying conn.
+func (s *Session) GoAway() error {
+	return s.waitForSend(s.goAway(goAwayNormal), nil)
+}
+
+// goAway builds a GoAway frame with the given reason code and marks the
+// session as no longer accepting new streams.
+func (s *Session) goAway(reason uint32) header {
+	// The previous value is irrelevant, so a plain store is sufficient
+	// (Swap with a discarded result was equivalent but misleading).
+	atomic.StoreInt32(&s.localGoAway, 1)
+	hdr := header(make([]byte, headerSize))
+	hdr.encode(typeGoAway, 0, 0, reason)
+	return hdr
+}
+
+// Ping is used to measure the RTT response time.
+// It sends a SYN ping and waits for the matching ACK, returning
+// ErrTimeout if the reply does not arrive within ConnectionWriteTimeout.
+func (s *Session) Ping() (time.Duration, error) {
+	// Get a channel for the ping
+	ch := make(chan struct{})
+
+	// Get a new ping id, mark as pending
+	s.pingLock.Lock()
+	id := s.pingID
+	s.pingID++
+	s.pings[id] = ch
+	s.pingLock.Unlock()
+
+	// Send the ping request; the id travels in the Length field.
+	hdr := header(make([]byte, headerSize))
+	hdr.encode(typePing, flagSYN, 0, id)
+	if err := s.waitForSend(hdr, nil); err != nil {
+		return 0, err
+	}
+
+	// Wait for a response
+	start := time.Now()
+	select {
+	case <-ch:
+	case <-time.After(s.config.ConnectionWriteTimeout):
+		s.pingLock.Lock()
+		delete(s.pings, id) // Ignore it if a response comes later.
+		s.pingLock.Unlock()
+		return 0, ErrTimeout
+	case <-s.shutdownCh:
+		return 0, ErrSessionShutdown
+	}
+
+	// Compute the RTT (time.Since is the idiomatic form of Now().Sub).
+	return time.Since(start), nil
+}
+
+// keepalive is a long running goroutine that periodically does
+// a ping to keep the connection alive. It exits on session shutdown,
+// or after escalating a failed ping into a keepalive-timeout shutdown.
+func (s *Session) keepalive() {
+	for {
+		select {
+		case <-time.After(s.config.KeepAliveInterval):
+			_, err := s.Ping()
+			if err != nil {
+				// A shutdown-induced failure is expected and not an error.
+				if err != ErrSessionShutdown {
+					s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
+					s.exitErr(ErrKeepAliveTimeout)
+				}
+				return
+			}
+		case <-s.shutdownCh:
+			return
+		}
+	}
+}
+
+// waitForSend waits to send a header, checking for a potential shutdown.
+// It is a convenience wrapper around waitForSendErr with a fresh,
+// buffered error channel.
+func (s *Session) waitForSend(hdr header, body io.Reader) error {
+	errCh := make(chan error, 1)
+	return s.waitForSendErr(hdr, body, errCh)
+}
+
+// waitForSendErr waits to send a header with optional data, checking for a
+// potential shutdown. Since there's the expectation that sends can happen
+// in a timely manner, we enforce the connection write timeout here.
+func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
+	t := timerPool.Get()
+	timer := t.(*time.Timer)
+	timer.Reset(s.config.ConnectionWriteTimeout)
+	defer func() {
+		// Drain a concurrently-fired timer before returning it to the
+		// pool so the next user doesn't see a stale tick.
+		timer.Stop()
+		select {
+		case <-timer.C:
+		default:
+		}
+		timerPool.Put(t)
+	}()
+
+	ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
+	// First wait to enqueue the frame with the send loop...
+	select {
+	case s.sendCh <- ready:
+	case <-s.shutdownCh:
+		return ErrSessionShutdown
+	case <-timer.C:
+		return ErrConnectionWriteTimeout
+	}
+
+	// ...then wait for the send loop to report the write's outcome.
+	select {
+	case err := <-errCh:
+		return err
+	case <-s.shutdownCh:
+		return ErrSessionShutdown
+	case <-timer.C:
+		return ErrConnectionWriteTimeout
+	}
+}
+
+// sendNoWait does a send without waiting. Since there's the expectation that
+// the send happens right here, we enforce the connection write timeout if we
+// can't queue the header to be sent. Unlike waitForSendErr, the caller is
+// not told whether the write itself eventually succeeded.
+func (s *Session) sendNoWait(hdr header) error {
+	t := timerPool.Get()
+	timer := t.(*time.Timer)
+	timer.Reset(s.config.ConnectionWriteTimeout)
+	defer func() {
+		// Drain a fired timer before pooling it (same pattern as
+		// waitForSendErr).
+		timer.Stop()
+		select {
+		case <-timer.C:
+		default:
+		}
+		timerPool.Put(t)
+	}()
+
+	select {
+	case s.sendCh <- sendReady{Hdr: hdr}:
+		return nil
+	case <-s.shutdownCh:
+		return ErrSessionShutdown
+	case <-timer.C:
+		return ErrConnectionWriteTimeout
+	}
+}
+
+// send is a long running goroutine that sends data.
+// It is the single writer to s.conn; any write failure is fatal to the
+// whole session (exitErr) after notifying the frame's waiter.
+func (s *Session) send() {
+	for {
+		select {
+		case ready := <-s.sendCh:
+			// Send a header if ready, looping until all bytes are written.
+			if ready.Hdr != nil {
+				sent := 0
+				for sent < len(ready.Hdr) {
+					n, err := s.conn.Write(ready.Hdr[sent:])
+					if err != nil {
+						s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
+						asyncSendErr(ready.Err, err)
+						s.exitErr(err)
+						return
+					}
+					sent += n
+				}
+			}
+
+			// Send data from a body if given
+			if ready.Body != nil {
+				_, err := io.Copy(s.conn, ready.Body)
+				if err != nil {
+					s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
+					asyncSendErr(ready.Err, err)
+					s.exitErr(err)
+					return
+				}
+			}
+
+			// No error, successful send
+			asyncSendErr(ready.Err, nil)
+		case <-s.shutdownCh:
+			return
+		}
+	}
+}
+
+// recv is a long running goroutine that accepts new data.
+// Any error from the receive loop terminates the session.
+func (s *Session) recv() {
+	if err := s.recvLoop(); err != nil {
+		s.exitErr(err)
+	}
+}
+
+// Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type
+var (
+	handlers = []func(*Session, header) error{
+		typeData:         (*Session).handleStreamMessage,
+		typeWindowUpdate: (*Session).handleStreamMessage,
+		typePing:         (*Session).handlePing,
+		typeGoAway:       (*Session).handleGoAway,
+	}
+)
+
+// recvLoop continues to receive data until a fatal error is encountered.
+// recvDoneCh is closed on exit so Close() can wait for this goroutine.
+func (s *Session) recvLoop() error {
+	defer close(s.recvDoneCh)
+	hdr := header(make([]byte, headerSize))
+	for {
+		// Read the header
+		if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
+			// Ordinary connection teardown is not worth logging.
+			if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
+				s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
+			}
+			return err
+		}
+
+		// Verify the version
+		if hdr.Version() != protoVersion {
+			s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
+			return ErrInvalidVersion
+		}
+
+		// Dispatch on the message type. mt is unsigned and typeData is 0,
+		// so only the upper bound needs checking (the old `mt < typeData`
+		// comparison was always false).
+		mt := hdr.MsgType()
+		if mt > typeGoAway {
+			return ErrInvalidMsgType
+		}
+
+		if err := handlers[mt](s, hdr); err != nil {
+			return err
+		}
+	}
+}
+
+// handleStreamMessage handles either a data or window update frame.
+// Protocol violations reply with a goAwayProtoErr before returning the
+// (fatal) error; frames for unknown streams are drained and dropped.
+func (s *Session) handleStreamMessage(hdr header) error {
+	// Check for a new stream creation
+	id := hdr.StreamID()
+	flags := hdr.Flags()
+	if flags&flagSYN == flagSYN {
+		if err := s.incomingStream(id); err != nil {
+			return err
+		}
+	}
+
+	// Get the stream
+	s.streamLock.Lock()
+	stream := s.streams[id]
+	s.streamLock.Unlock()
+
+	// If we do not have a stream, likely we sent a RST
+	if stream == nil {
+		// Drain any data on the wire so we stay frame-aligned.
+		if hdr.MsgType() == typeData && hdr.Length() > 0 {
+			s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id)
+			if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil {
+				s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err)
+				return nil
+			}
+		} else {
+			s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr)
+		}
+		return nil
+	}
+
+	// Check if this is a window update
+	if hdr.MsgType() == typeWindowUpdate {
+		if err := stream.incrSendWindow(hdr, flags); err != nil {
+			if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
+				s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
+			}
+			return err
+		}
+		return nil
+	}
+
+	// Read the new data
+	if err := stream.readData(hdr, flags, s.bufRead); err != nil {
+		if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
+			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
+		}
+		return err
+	}
+	return nil
+}
+
+// handlePing is invoked for a typePing frame. SYN pings are answered
+// with an ACK carrying the same ping ID; ACK pings wake up the waiter
+// registered in Ping().
+func (s *Session) handlePing(hdr header) error {
+	flags := hdr.Flags()
+	pingID := hdr.Length()
+
+	// Check if this is a query, respond back in a separate context so we
+	// don't interfere with the receiving thread blocking for the write.
+	if flags&flagSYN == flagSYN {
+		go func() {
+			hdr := header(make([]byte, headerSize))
+			hdr.encode(typePing, flagACK, 0, pingID)
+			if err := s.sendNoWait(hdr); err != nil {
+				s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err)
+			}
+		}()
+		return nil
+	}
+
+	// Handle a response; a missing entry means Ping() already timed out.
+	s.pingLock.Lock()
+	ch := s.pings[pingID]
+	if ch != nil {
+		delete(s.pings, pingID)
+		close(ch)
+	}
+	s.pingLock.Unlock()
+	return nil
+}
+
+// handleGoAway is invoked for a typeGoAway frame. A normal go-away only
+// stops new stream opens; error codes terminate the session.
+func (s *Session) handleGoAway(hdr header) error {
+	code := hdr.Length()
+	switch code {
+	case goAwayNormal:
+		atomic.SwapInt32(&s.remoteGoAway, 1)
+	case goAwayProtoErr:
+		s.logger.Printf("[ERR] yamux: received protocol error go away")
+		return fmt.Errorf("yamux protocol error")
+	case goAwayInternalErr:
+		s.logger.Printf("[ERR] yamux: received internal error go away")
+		return fmt.Errorf("remote yamux internal error")
+	default:
+		s.logger.Printf("[ERR] yamux: received unexpected go away")
+		return fmt.Errorf("unexpected go away received")
+	}
+	return nil
+}
+
+// incomingStream is used to create a new incoming stream for a SYN.
+// It RSTs the stream when we are going away or the accept backlog is
+// full, and treats a duplicate stream ID as a protocol error.
+func (s *Session) incomingStream(id uint32) error {
+	// Reject immediately if we are doing a go away
+	if atomic.LoadInt32(&s.localGoAway) == 1 {
+		hdr := header(make([]byte, headerSize))
+		hdr.encode(typeWindowUpdate, flagRST, id, 0)
+		return s.sendNoWait(hdr)
+	}
+
+	// Allocate a new stream
+	stream := newStream(s, id, streamSYNReceived)
+
+	s.streamLock.Lock()
+	defer s.streamLock.Unlock()
+
+	// Check if stream already exists
+	if _, ok := s.streams[id]; ok {
+		s.logger.Printf("[ERR] yamux: duplicate stream declared")
+		if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
+			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
+		}
+		return ErrDuplicateStream
+	}
+
+	// Register the stream
+	s.streams[id] = stream
+
+	// Check if we've exceeded the backlog (non-blocking send).
+	select {
+	case s.acceptCh <- stream:
+		return nil
+	default:
+		// Backlog exceeded! RST the stream
+		s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
+		delete(s.streams, id)
+		stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
+		return s.sendNoWait(stream.sendHdr)
+	}
+}
+
+// closeStream is used to close a stream once both sides have
+// issued a close. If there was an in-flight SYN and the stream
+// was not yet established, then this will give the credit back.
+func (s *Session) closeStream(id uint32) {
+	s.streamLock.Lock()
+	if _, ok := s.inflight[id]; ok {
+		select {
+		case <-s.synCh:
+		default:
+			s.logger.Printf("[ERR] yamux: SYN tracking out of sync")
+		}
+		// Remove the inflight entry too: establishStream will never run
+		// for this stream, so leaving it would leak the map entry.
+		delete(s.inflight, id)
+	}
+	delete(s.streams, id)
+	s.streamLock.Unlock()
+}
+
+// establishStream is used to mark a stream that was in the
+// SYN Sent state as established. It releases both halves of the SYN
+// accounting: the inflight map entry and the synCh semaphore slot.
+func (s *Session) establishStream(id uint32) {
+	s.streamLock.Lock()
+	if _, ok := s.inflight[id]; ok {
+		delete(s.inflight, id)
+	} else {
+		s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)")
+	}
+	select {
+	case <-s.synCh:
+	default:
+		s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)")
+	}
+	s.streamLock.Unlock()
+}
--- /dev/null
+package yamux
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "reflect"
+ "runtime"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+// logCapture is an in-memory log sink used to assert on session log output.
+type logCapture struct{ bytes.Buffer }
+
+// logs returns the captured output split into individual lines.
+func (l *logCapture) logs() []string {
+	return strings.Split(strings.TrimSpace(l.String()), "\n")
+}
+
+// match reports whether the captured lines equal expect exactly.
+func (l *logCapture) match(expect []string) bool {
+	return reflect.DeepEqual(l.logs(), expect)
+}
+
+// captureLogs redirects the session's logger into a fresh logCapture.
+func captureLogs(s *Session) *logCapture {
+	buf := new(logCapture)
+	s.logger = log.New(buf, "", 0)
+	return buf
+}
+
+// pipeConn is an in-process ReadWriteCloser built from io.Pipe halves.
+// writeBlocker lets tests freeze all writes by holding the mutex.
+type pipeConn struct {
+	reader       *io.PipeReader
+	writer       *io.PipeWriter
+	writeBlocker sync.Mutex
+}
+
+func (p *pipeConn) Read(b []byte) (int, error) {
+	return p.reader.Read(b)
+}
+
+// Write blocks while writeBlocker is held, which tests use to simulate
+// a stalled connection.
+func (p *pipeConn) Write(b []byte) (int, error) {
+	p.writeBlocker.Lock()
+	defer p.writeBlocker.Unlock()
+	return p.writer.Write(b)
+}
+
+func (p *pipeConn) Close() error {
+	p.reader.Close()
+	return p.writer.Close()
+}
+
+// testConn returns two pipeConns cross-wired so that writes on one side
+// are reads on the other, like a loopback TCP connection.
+func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
+	read1, write1 := io.Pipe()
+	read2, write2 := io.Pipe()
+	conn1 := &pipeConn{reader: read1, writer: write2}
+	conn2 := &pipeConn{reader: read2, writer: write1}
+	return conn1, conn2
+}
+
+// testConf returns a config tuned for fast tests: small backlog and
+// short keepalive/write timeouts.
+func testConf() *Config {
+	conf := DefaultConfig()
+	conf.AcceptBacklog = 64
+	conf.KeepAliveInterval = 100 * time.Millisecond
+	conf.ConnectionWriteTimeout = 250 * time.Millisecond
+	return conf
+}
+
+// testConfNoKeepAlive is testConf with keepalives disabled.
+func testConfNoKeepAlive() *Config {
+	conf := testConf()
+	conf.EnableKeepAlive = false
+	return conf
+}
+
+// testClientServer builds a connected client/server session pair with
+// the default test config.
+func testClientServer() (*Session, *Session) {
+	return testClientServerConfig(testConf())
+}
+
+// testClientServerConfig builds a connected client/server session pair
+// over an in-process pipe using the given config.
+func testClientServerConfig(conf *Config) (*Session, *Session) {
+	conn1, conn2 := testConn()
+	client, _ := Client(conn1, conf)
+	server, _ := Server(conn2, conf)
+	return client, server
+}
+
+// TestPing verifies that a round-trip ping works from both sides and
+// reports a non-zero RTT.
+func TestPing(t *testing.T) {
+	client, server := testClientServer()
+	defer client.Close()
+	defer server.Close()
+
+	rtt, err := client.Ping()
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if rtt == 0 {
+		t.Fatalf("bad: %v", rtt)
+	}
+
+	rtt, err = server.Ping()
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if rtt == 0 {
+		t.Fatalf("bad: %v", rtt)
+	}
+}
+
+// TestPing_Timeout blocks the client's writes so a server ping cannot be
+// answered, expects ErrTimeout, then unblocks and verifies pings recover.
+func TestPing_Timeout(t *testing.T) {
+	client, server := testClientServerConfig(testConfNoKeepAlive())
+	defer client.Close()
+	defer server.Close()
+
+	// Prevent the client from responding
+	clientConn := client.conn.(*pipeConn)
+	clientConn.writeBlocker.Lock()
+
+	errCh := make(chan error, 1)
+	go func() {
+		_, err := server.Ping() // Ping via the server session
+		errCh <- err
+	}()
+
+	select {
+	case err := <-errCh:
+		if err != ErrTimeout {
+			t.Fatalf("err: %v", err)
+		}
+	case <-time.After(client.config.ConnectionWriteTimeout * 2):
+		t.Fatalf("failed to timeout within expected %v", client.config.ConnectionWriteTimeout)
+	}
+
+	// Verify that we recover, even if we gave up
+	clientConn.writeBlocker.Unlock()
+
+	go func() {
+		_, err := server.Ping() // Ping via the server session
+		errCh <- err
+	}()
+
+	select {
+	case err := <-errCh:
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	case <-time.After(client.config.ConnectionWriteTimeout):
+		t.Fatalf("timeout")
+	}
+}
+
+// TestCloseBeforeAck fills the accept backlog with streams that are
+// closed before being ACK'd, then verifies a subsequent OpenStream does
+// not block forever on the SYN semaphore.
+func TestCloseBeforeAck(t *testing.T) {
+	cfg := testConf()
+	cfg.AcceptBacklog = 8
+	client, server := testClientServerConfig(cfg)
+
+	defer client.Close()
+	defer server.Close()
+
+	for i := 0; i < 8; i++ {
+		s, err := client.OpenStream()
+		if err != nil {
+			t.Fatal(err)
+		}
+		s.Close()
+	}
+
+	for i := 0; i < 8; i++ {
+		s, err := server.AcceptStream()
+		if err != nil {
+			t.Fatal(err)
+		}
+		s.Close()
+	}
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		s, err := client.OpenStream()
+		if err != nil {
+			// t.Fatal must not be called outside the test goroutine;
+			// report and bail out instead.
+			t.Error(err)
+			return
+		}
+		s.Close()
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(time.Second * 5):
+		t.Fatal("timed out trying to open stream")
+	}
+}
+
+// TestAccept opens one stream in each direction concurrently and checks
+// the stream ID parity (client-opened streams get ID 1, server-opened
+// streams get ID 2).
+//
+// NOTE(review): the goroutines below call t.Fatalf, which the testing
+// package only supports from the test's own goroutine; consider
+// t.Error + return instead.
+func TestAccept(t *testing.T) {
+	client, server := testClientServer()
+	defer client.Close()
+	defer server.Close()
+
+	if client.NumStreams() != 0 {
+		t.Fatalf("bad")
+	}
+	if server.NumStreams() != 0 {
+		t.Fatalf("bad")
+	}
+
+	wg := &sync.WaitGroup{}
+	wg.Add(4)
+
+	go func() {
+		defer wg.Done()
+		stream, err := server.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		if id := stream.StreamID(); id != 1 {
+			t.Fatalf("bad: %v", id)
+		}
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	go func() {
+		defer wg.Done()
+		stream, err := client.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		if id := stream.StreamID(); id != 2 {
+			t.Fatalf("bad: %v", id)
+		}
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	go func() {
+		defer wg.Done()
+		stream, err := server.OpenStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		if id := stream.StreamID(); id != 2 {
+			t.Fatalf("bad: %v", id)
+		}
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	go func() {
+		defer wg.Done()
+		stream, err := client.OpenStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		if id := stream.StreamID(); id != 1 {
+			t.Fatalf("bad: %v", id)
+		}
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	doneCh := make(chan struct{})
+	go func() {
+		wg.Wait()
+		close(doneCh)
+	}()
+
+	select {
+	case <-doneCh:
+	case <-time.After(time.Second):
+		panic("timeout")
+	}
+}
+
+// TestNonNilInterface guards against the typed-nil-in-interface trap:
+// Accept/Open must return a literal nil net.Conn on error, not a non-nil
+// interface wrapping a nil *Stream.
+func TestNonNilInterface(t *testing.T) {
+	_, server := testClientServer()
+	server.Close()
+
+	conn, err := server.Accept()
+	if err != nil && conn != nil {
+		t.Error("bad: accept should return a connection of nil value")
+	}
+
+	conn, err = server.Open()
+	if err != nil && conn != nil {
+		t.Error("bad: open should return a connection of nil value")
+	}
+}
+
+// TestSendData_Small streams 1000 small 4-byte writes from client to
+// server and verifies each read returns the exact payload, then checks
+// both sides report zero open streams after close.
+//
+// NOTE(review): t.Fatalf inside the goroutines should ideally be
+// t.Error + return (FailNow only works from the test goroutine).
+func TestSendData_Small(t *testing.T) {
+	client, server := testClientServer()
+	defer client.Close()
+	defer server.Close()
+
+	wg := &sync.WaitGroup{}
+	wg.Add(2)
+
+	go func() {
+		defer wg.Done()
+		stream, err := server.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+
+		if server.NumStreams() != 1 {
+			t.Fatalf("bad")
+		}
+
+		buf := make([]byte, 4)
+		for i := 0; i < 1000; i++ {
+			n, err := stream.Read(buf)
+			if err != nil {
+				t.Fatalf("err: %v", err)
+			}
+			if n != 4 {
+				t.Fatalf("short read: %d", n)
+			}
+			if string(buf) != "test" {
+				t.Fatalf("bad: %s", buf)
+			}
+		}
+
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	go func() {
+		defer wg.Done()
+		stream, err := client.Open()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+
+		if client.NumStreams() != 1 {
+			t.Fatalf("bad")
+		}
+
+		for i := 0; i < 1000; i++ {
+			n, err := stream.Write([]byte("test"))
+			if err != nil {
+				t.Fatalf("err: %v", err)
+			}
+			if n != 4 {
+				t.Fatalf("short write %d", n)
+			}
+		}
+
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	doneCh := make(chan struct{})
+	go func() {
+		wg.Wait()
+		close(doneCh)
+	}()
+	select {
+	case <-doneCh:
+	case <-time.After(time.Second):
+		panic("timeout")
+	}
+
+	if client.NumStreams() != 0 {
+		t.Fatalf("bad")
+	}
+	if server.NumStreams() != 0 {
+		t.Fatalf("bad")
+	}
+}
+
+// TestSendData_Large pushes a 250 MiB patterned payload through one
+// stream and verifies the receiver sees every byte, exercising window
+// updates and flow control under a 5s deadline.
+func TestSendData_Large(t *testing.T) {
+	client, server := testClientServer()
+	defer client.Close()
+	defer server.Close()
+
+	const (
+		sendSize = 250 * 1024 * 1024
+		recvSize = 4 * 1024
+	)
+
+	// Fill the payload with a position-derived pattern so corruption or
+	// reordering is detectable at any offset.
+	data := make([]byte, sendSize)
+	for idx := range data {
+		data[idx] = byte(idx % 256)
+	}
+
+	wg := &sync.WaitGroup{}
+	wg.Add(2)
+
+	go func() {
+		defer wg.Done()
+		stream, err := server.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		var sz int
+		buf := make([]byte, recvSize)
+		for i := 0; i < sendSize/recvSize; i++ {
+			n, err := stream.Read(buf)
+			if err != nil {
+				t.Fatalf("err: %v", err)
+			}
+			if n != recvSize {
+				t.Fatalf("short read: %d", n)
+			}
+			sz += n
+			for idx := range buf {
+				if buf[idx] != byte(idx%256) {
+					t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
+				}
+			}
+		}
+
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+
+		t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
+	}()
+
+	go func() {
+		defer wg.Done()
+		stream, err := client.Open()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+
+		n, err := stream.Write(data)
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		if n != len(data) {
+			t.Fatalf("short write %d", n)
+		}
+
+		if err := stream.Close(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	doneCh := make(chan struct{})
+	go func() {
+		wg.Wait()
+		close(doneCh)
+	}()
+	select {
+	case <-doneCh:
+	case <-time.After(5 * time.Second):
+		panic("timeout")
+	}
+}
+
+// TestGoAway verifies that after the server sends a GoAway, client
+// opens fail with ErrRemoteGoAway.
+func TestGoAway(t *testing.T) {
+	client, server := testClientServer()
+	defer client.Close()
+	defer server.Close()
+
+	if err := server.GoAway(); err != nil {
+		t.Fatalf("err: %v", err)
+	}
+
+	_, err := client.Open()
+	if err != ErrRemoteGoAway {
+		t.Fatalf("err: %v", err)
+	}
+}
+
// TestManyStreams runs 50 concurrent stream pairs over one session;
// each sender writes 1000 small messages and its acceptor drains
// until EOF. Exercises interleaved frames and stream multiplexing.
func TestManyStreams(t *testing.T) {
	client, server := testClientServer()
	defer client.Close()
	defer server.Close()

	wg := &sync.WaitGroup{}

	acceptor := func(i int) {
		defer wg.Done()
		stream, err := server.AcceptStream()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer stream.Close()

		// Drain until the sender half-closes (io.EOF).
		buf := make([]byte, 512)
		for {
			n, err := stream.Read(buf)
			if err == io.EOF {
				return
			}
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			if n == 0 {
				t.Fatalf("err: %v", err)
			}
		}
	}
	sender := func(i int) {
		defer wg.Done()
		stream, err := client.Open()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer stream.Close()

		// Fixed-width message so every write has a known length.
		msg := fmt.Sprintf("%08d", i)
		for i := 0; i < 1000; i++ {
			n, err := stream.Write([]byte(msg))
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			if n != len(msg) {
				t.Fatalf("short write %d", n)
			}
		}
	}

	for i := 0; i < 50; i++ {
		wg.Add(2)
		go acceptor(i)
		go sender(i)
	}

	wg.Wait()
}
+
// TestManyStreams_PingPong runs 50 concurrent streams doing 1000
// ping/pong round trips each, calling Shrink between rounds so the
// receive buffers are repeatedly released and lazily reallocated.
func TestManyStreams_PingPong(t *testing.T) {
	client, server := testClientServer()
	defer client.Close()
	defer server.Close()

	wg := &sync.WaitGroup{}

	ping := []byte("ping")
	pong := []byte("pong")

	acceptor := func(i int) {
		defer wg.Done()
		stream, err := server.AcceptStream()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer stream.Close()

		buf := make([]byte, 4)
		for {
			// Read the 'ping'
			n, err := stream.Read(buf)
			if err == io.EOF {
				return
			}
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			if n != 4 {
				t.Fatalf("err: %v", err)
			}
			if !bytes.Equal(buf, ping) {
				t.Fatalf("bad: %s", buf)
			}

			// Shrink the internal buffer!
			stream.Shrink()

			// Write out the 'pong'
			n, err = stream.Write(pong)
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			if n != 4 {
				t.Fatalf("err: %v", err)
			}
		}
	}
	sender := func(i int) {
		defer wg.Done()
		stream, err := client.OpenStream()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer stream.Close()

		buf := make([]byte, 4)
		for i := 0; i < 1000; i++ {
			// Send the 'ping'
			n, err := stream.Write(ping)
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			if n != 4 {
				t.Fatalf("short write %d", n)
			}

			// Read the 'pong'
			n, err = stream.Read(buf)
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			if n != 4 {
				t.Fatalf("err: %v", err)
			}
			if !bytes.Equal(buf, pong) {
				t.Fatalf("bad: %s", buf)
			}

			// Shrink the buffer
			stream.Shrink()
		}
	}

	for i := 0; i < 50; i++ {
		wg.Add(2)
		go acceptor(i)
		go sender(i)
	}

	wg.Wait()
}
+
// TestHalfClose exercises half-close semantics: the receiving side
// closes first (sending FIN) but can still drain bytes written both
// before and after that point, and only observes io.EOF once the
// writer has also closed and the buffer is empty.
func TestHalfClose(t *testing.T) {
	client, server := testClientServer()
	defer client.Close()
	defer server.Close()

	stream, err := client.Open()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if _, err = stream.Write([]byte("a")); err != nil {
		t.Fatalf("err: %v", err)
	}

	stream2, err := server.Accept()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	stream2.Close() // Half close

	// Data written before the half-close is still readable.
	buf := make([]byte, 4)
	n, err := stream2.Read(buf)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if n != 1 {
		t.Fatalf("bad: %v", n)
	}

	// Send more
	if _, err = stream.Write([]byte("bcd")); err != nil {
		t.Fatalf("err: %v", err)
	}
	stream.Close()

	// Read after close
	n, err = stream2.Read(buf)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if n != 3 {
		t.Fatalf("bad: %v", n)
	}

	// EOF after close
	n, err = stream2.Read(buf)
	if err != io.EOF {
		t.Fatalf("err: %v", err)
	}
	if n != 0 {
		t.Fatalf("bad: %v", n)
	}
}
+
+func TestReadDeadline(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ stream2, err := server.Accept()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream2.Close()
+
+ if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ buf := make([]byte, 4)
+ if _, err := stream.Read(buf); err != ErrTimeout {
+ t.Fatalf("err: %v", err)
+ }
+}
+
// TestWriteDeadline verifies that Write returns ErrTimeout once the
// 50ms write deadline expires. The peer never reads, so repeated
// writes eventually exhaust the send window and block.
func TestWriteDeadline(t *testing.T) {
	client, server := testClientServer()
	defer client.Close()
	defer server.Close()

	stream, err := client.Open()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	defer stream.Close()

	stream2, err := server.Accept()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	defer stream2.Close()

	if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Far more iterations than needed to fill the window; the
	// expected exit path is the ErrTimeout return inside the loop.
	buf := make([]byte, 512)
	for i := 0; i < int(initialStreamWindow); i++ {
		_, err := stream.Write(buf)
		if err != nil && err == ErrTimeout {
			return
		} else if err != nil {
			t.Fatalf("err: %v", err)
		}
	}
	t.Fatalf("Expected timeout")
}
+
// TestBacklogExceeded fills the server's accept backlog (the server
// never calls Accept), then checks that one additional Open does not
// complete successfully and instead errors out once the server shuts
// down.
func TestBacklogExceeded(t *testing.T) {
	client, server := testClientServer()
	defer client.Close()
	defer server.Close()

	// Fill the backlog
	max := client.config.AcceptBacklog
	for i := 0; i < max; i++ {
		stream, err := client.Open()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer stream.Close()

		if _, err := stream.Write([]byte("foo")); err != nil {
			t.Fatalf("err: %v", err)
		}
	}

	// Attempt to open a new stream
	errCh := make(chan error, 1)
	go func() {
		_, err := client.Open()
		errCh <- err
	}()

	// Shutdown the server
	go func() {
		time.Sleep(10 * time.Millisecond)
		server.Close()
	}()

	select {
	case err := <-errCh:
		if err == nil {
			t.Fatalf("open should fail")
		}
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}
}
+
+func TestKeepAlive(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ time.Sleep(200 * time.Millisecond)
+
+ // Ping value should increase
+ client.pingLock.Lock()
+ defer client.pingLock.Unlock()
+ if client.pingID == 0 {
+ t.Fatalf("should ping")
+ }
+
+ server.pingLock.Lock()
+ defer server.pingLock.Unlock()
+ if server.pingID == 0 {
+ t.Fatalf("should ping")
+ }
+}
+
+func TestKeepAlive_Timeout(t *testing.T) {
+ conn1, conn2 := testConn()
+
+ clientConf := testConf()
+ clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes
+ clientConf.EnableKeepAlive = false // Just test one direction, so it's deterministic who hangs up on whom
+ client, _ := Client(conn1, clientConf)
+ defer client.Close()
+
+ server, _ := Server(conn2, testConf())
+ defer server.Close()
+
+ _ = captureLogs(client) // Client logs aren't part of the test
+ serverLogs := captureLogs(server)
+
+ errCh := make(chan error, 1)
+ go func() {
+ _, err := server.Accept() // Wait until server closes
+ errCh <- err
+ }()
+
+ // Prevent the client from responding
+ clientConn := client.conn.(*pipeConn)
+ clientConn.writeBlocker.Lock()
+
+ select {
+ case err := <-errCh:
+ if err != ErrKeepAliveTimeout {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("timeout waiting for timeout")
+ }
+
+ if !server.IsClosed() {
+ t.Fatalf("server should have closed")
+ }
+
+ if !serverLogs.match([]string{"[ERR] yamux: keepalive failed: i/o deadline reached"}) {
+ t.Fatalf("server log incorect: %v", serverLogs.logs())
+ }
+}
+
// TestLargeWindow doubles MaxStreamWindowSize and verifies that a
// single Write of one full (enlarged) window completes within a 10ms
// write deadline, i.e. the larger window is actually honored.
func TestLargeWindow(t *testing.T) {
	conf := DefaultConfig()
	conf.MaxStreamWindowSize *= 2

	client, server := testClientServerConfig(conf)
	defer client.Close()
	defer server.Close()

	stream, err := client.Open()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	defer stream.Close()

	stream2, err := server.Accept()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	defer stream2.Close()

	// The deadline only trips if the write blocks on the window.
	stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
	buf := make([]byte, conf.MaxStreamWindowSize)
	n, err := stream.Write(buf)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if n != len(buf) {
		t.Fatalf("short write: %d", n)
	}
}
+
// UnlimitedReader is an io.Reader producing an endless stream: every
// Read reports the full buffer as filled without writing to it, so
// the buffer contents are whatever was there before (callers only
// measure throughput, not data).
type UnlimitedReader struct{}

// Read yields to the scheduler, then claims to have filled p.
func (u *UnlimitedReader) Read(p []byte) (int, error) {
	runtime.Gosched()
	return len(p), nil
}
+
// TestSendData_VeryLarge runs 16 concurrent streams each pushing 1GB
// of synthetic data; the server checks a 4-byte header, discards the
// bulk payload, and verifies only the total byte count. A 20s
// watchdog converts a stall into a panic.
func TestSendData_VeryLarge(t *testing.T) {
	client, server := testClientServer()
	defer client.Close()
	defer server.Close()

	var n int64 = 1 * 1024 * 1024 * 1024
	var workers int = 16

	wg := &sync.WaitGroup{}
	wg.Add(workers * 2)

	for i := 0; i < workers; i++ {
		go func() {
			defer wg.Done()
			stream, err := server.AcceptStream()
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			defer stream.Close()

			buf := make([]byte, 4)
			_, err = stream.Read(buf)
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
				t.Fatalf("bad header")
			}

			// Discard the payload; only the total count matters.
			recv, err := io.Copy(ioutil.Discard, stream)
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			if recv != n {
				t.Fatalf("bad: %v", recv)
			}
		}()
	}
	for i := 0; i < workers; i++ {
		go func() {
			defer wg.Done()
			stream, err := client.Open()
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			defer stream.Close()

			_, err = stream.Write([]byte{0, 1, 2, 3})
			if err != nil {
				t.Fatalf("err: %v", err)
			}

			// UnlimitedReader claims to fill any buffer; LimitReader
			// caps the copy at exactly n bytes.
			unlimited := &UnlimitedReader{}
			sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			if sent != n {
				t.Fatalf("bad: %v", sent)
			}
		}()
	}

	doneCh := make(chan struct{})
	go func() {
		wg.Wait()
		close(doneCh)
	}()
	select {
	case <-doneCh:
	case <-time.After(20 * time.Second):
		panic("timeout")
	}
}
+
// TestBacklogExceeded_Accept opens 5x the accept backlog worth of
// streams while the server continuously Accepts, verifying the
// backlog drains and Open never gets stuck. The defer-in-loop usage
// is intentional: every stream stays open until the test returns.
func TestBacklogExceeded_Accept(t *testing.T) {
	client, server := testClientServer()
	defer client.Close()
	defer server.Close()

	max := 5 * client.config.AcceptBacklog
	go func() {
		for i := 0; i < max; i++ {
			stream, err := server.Accept()
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			defer stream.Close()
		}
	}()

	// Fill the backlog
	for i := 0; i < max; i++ {
		stream, err := client.Open()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer stream.Close()

		if _, err := stream.Write([]byte("foo")); err != nil {
			t.Fatalf("err: %v", err)
		}
	}
}
+
// TestSession_WindowUpdateWriteDuringRead checks that a Read which
// must emit a window update fails with ErrConnectionWriteTimeout when
// the connection's write side is blocked, instead of deadlocking.
func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
	client, server := testClientServerConfig(testConfNoKeepAlive())
	defer client.Close()
	defer server.Close()

	var wg sync.WaitGroup
	wg.Add(2)

	// Choose a huge flood size that we know will result in a window update.
	flood := int64(client.config.MaxStreamWindowSize) - 1

	// The server will accept a new stream and then flood data to it.
	go func() {
		defer wg.Done()

		stream, err := server.AcceptStream()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer stream.Close()

		n, err := stream.Write(make([]byte, flood))
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		if int64(n) != flood {
			t.Fatalf("short write: %d", n)
		}
	}()

	// The client will open a stream, block outbound writes, and then
	// listen to the flood from the server, which should time out since
	// it won't be able to send the window update.
	go func() {
		defer wg.Done()

		stream, err := client.OpenStream()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer stream.Close()

		conn := client.conn.(*pipeConn)
		conn.writeBlocker.Lock()

		_, err = stream.Read(make([]byte, flood))
		if err != ErrConnectionWriteTimeout {
			t.Fatalf("err: %v", err)
		}
	}()

	wg.Wait()
}
+
+func TestSession_PartialReadWindowUpdate(t *testing.T) {
+ client, server := testClientServerConfig(testConfNoKeepAlive())
+ defer client.Close()
+ defer server.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+
+ // Choose a huge flood size that we know will result in a window update.
+ flood := int64(client.config.MaxStreamWindowSize)
+ var wr *Stream
+
+ // The server will accept a new stream and then flood data to it.
+ go func() {
+ defer wg.Done()
+
+ var err error
+ wr, err = server.AcceptStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer wr.Close()
+
+ if wr.sendWindow != client.config.MaxStreamWindowSize {
+ t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow)
+ }
+
+ n, err := wr.Write(make([]byte, flood))
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if int64(n) != flood {
+ t.Fatalf("short write: %d", n)
+ }
+ if wr.sendWindow != 0 {
+ t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow)
+ }
+ }()
+
+ stream, err := client.OpenStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ wg.Wait()
+
+ _, err = stream.Read(make([]byte, flood/2+1))
+
+ if exp := uint32(flood/2 + 1); wr.sendWindow != exp {
+ t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow)
+ }
+}
+
// TestSession_sendNoWait_Timeout blocks the client's outbound
// connection and repeatedly queues ping headers via sendNoWait until
// the internal send queue fills, verifying the call then fails with
// ErrConnectionWriteTimeout instead of blocking forever.
func TestSession_sendNoWait_Timeout(t *testing.T) {
	client, server := testClientServerConfig(testConfNoKeepAlive())
	defer client.Close()
	defer server.Close()

	var wg sync.WaitGroup
	wg.Add(2)

	go func() {
		defer wg.Done()

		stream, err := server.AcceptStream()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer stream.Close()
	}()

	// The client will open the stream and then block outbound writes, we'll
	// probe sendNoWait once it gets into that state.
	go func() {
		defer wg.Done()

		stream, err := client.OpenStream()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer stream.Close()

		conn := client.conn.(*pipeConn)
		conn.writeBlocker.Lock()

		hdr := header(make([]byte, headerSize))
		hdr.encode(typePing, flagACK, 0, 0)
		// Keep queueing until the send path reports the timeout; any
		// other error is a failure.
		for {
			err = client.sendNoWait(hdr)
			if err == nil {
				continue
			} else if err == ErrConnectionWriteTimeout {
				break
			} else {
				t.Fatalf("err: %v", err)
			}
		}
	}()

	wg.Wait()
}
+
// TestSession_PingOfDeath ensures a server that cannot write (blocked
// connection, flooded send queue) does not deadlock its read loop
// when a ping arrives; once writes resume, pings succeed again.
func TestSession_PingOfDeath(t *testing.T) {
	client, server := testClientServerConfig(testConfNoKeepAlive())
	defer client.Close()
	defer server.Close()

	var wg sync.WaitGroup
	wg.Add(2)

	// Used as a one-shot signal: unlocked once the server's send
	// queue has been flooded.
	var doPingOfDeath sync.Mutex
	doPingOfDeath.Lock()

	// This is used later to block outbound writes.
	conn := server.conn.(*pipeConn)

	// The server will accept a stream, block outbound writes, and then
	// flood its send channel so that no more headers can be queued.
	go func() {
		defer wg.Done()

		stream, err := server.AcceptStream()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer stream.Close()

		conn.writeBlocker.Lock()
		for {
			hdr := header(make([]byte, headerSize))
			hdr.encode(typePing, 0, 0, 0)
			err = server.sendNoWait(hdr)
			if err == nil {
				continue
			} else if err == ErrConnectionWriteTimeout {
				break
			} else {
				t.Fatalf("err: %v", err)
			}
		}

		doPingOfDeath.Unlock()
	}()

	// The client will open a stream and then send the server a ping once it
	// can no longer write. This makes sure the server doesn't deadlock reads
	// while trying to reply to the ping with no ability to write.
	go func() {
		defer wg.Done()

		stream, err := client.OpenStream()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer stream.Close()

		// This ping will never unblock because the ping id will never
		// show up in a response.
		doPingOfDeath.Lock()
		go func() { client.Ping() }()

		// Wait for a while to make sure the previous ping times out,
		// then turn writes back on and make sure a ping works again.
		time.Sleep(2 * server.config.ConnectionWriteTimeout)
		conn.writeBlocker.Unlock()
		if _, err = client.Ping(); err != nil {
			t.Fatalf("err: %v", err)
		}
	}()

	wg.Wait()
}
+
// TestSession_ConnectionWriteTimeout verifies that a stream Write
// returns ErrConnectionWriteTimeout (and reports 0 bytes written)
// when the underlying connection's write path is blocked.
func TestSession_ConnectionWriteTimeout(t *testing.T) {
	client, server := testClientServerConfig(testConfNoKeepAlive())
	defer client.Close()
	defer server.Close()

	var wg sync.WaitGroup
	wg.Add(2)

	go func() {
		defer wg.Done()

		stream, err := server.AcceptStream()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer stream.Close()
	}()

	// The client will open the stream and then block outbound writes, we'll
	// tee up a write and make sure it eventually times out.
	go func() {
		defer wg.Done()

		stream, err := client.OpenStream()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		defer stream.Close()

		conn := client.conn.(*pipeConn)
		conn.writeBlocker.Lock()

		// Since the write goroutine is blocked then this will return a
		// timeout since it can't get feedback about whether the write
		// worked.
		n, err := stream.Write([]byte("hello"))
		if err != ErrConnectionWriteTimeout {
			t.Fatalf("err: %v", err)
		}
		if n != 0 {
			t.Fatalf("lied about writes: %d", n)
		}
	}()

	wg.Wait()
}
--- /dev/null
+# Specification
+
+We use this document to detail the internal specification of Yamux.
+This is used both as a guide for implementing Yamux, but also for
+alternative interoperable libraries to be built.
+
+# Framing
+
+Yamux uses a streaming connection underneath, but imposes a message
+framing so that it can be shared between many logical streams. Each
+frame contains a header like:
+
+* Version (8 bits)
+* Type (8 bits)
+* Flags (16 bits)
+* StreamID (32 bits)
+* Length (32 bits)
+
+This means that each header has a 12 byte overhead.
+All fields are encoded in network order (big endian).
+Each field is described below:
+
+## Version Field
+
+The version field is used for future backward compatibility. At the
+current time, the field is always set to 0, to indicate the initial
+version.
+
+## Type Field
+
+The type field is used to switch the frame message type. The following
+message types are supported:
+
+* 0x0 Data - Used to transmit data. May transmit zero length payloads
+ depending on the flags.
+
* 0x1 Window Update - Used to update the sender's receive window size.
  This is used to implement per-stream flow control.
+
+* 0x2 Ping - Used to measure RTT. It can also be used to heart-beat
+ and do keep-alives over TCP.
+
+* 0x3 Go Away - Used to close a session.
+
+## Flag Field
+
+The flags field is used to provide additional information related
+to the message type. The following flags are supported:
+
+* 0x1 SYN - Signals the start of a new stream. May be sent with a data or
+ window update message. Also sent with a ping to indicate outbound.
+
+* 0x2 ACK - Acknowledges the start of a new stream. May be sent with a data
+ or window update message. Also sent with a ping to indicate response.
+
+* 0x4 FIN - Performs a half-close of a stream. May be sent with a data
+ message or window update.
+
+* 0x8 RST - Reset a stream immediately. May be sent with a data or
+ window update message.
+
+## StreamID Field
+
+The StreamID field is used to identify the logical stream the frame
+is addressing. The client side should use odd ID's, and the server even.
+This prevents any collisions. Additionally, the 0 ID is reserved to represent
+the session.
+
+Both Ping and Go Away messages should always use the 0 StreamID.
+
+## Length Field
+
+The meaning of the length field depends on the message type:
+
+* Data - provides the length of bytes following the header
+* Window update - provides a delta update to the window size
+* Ping - Contains an opaque value, echoed back
+* Go Away - Contains an error code
+
+# Message Flow
+
+There is no explicit connection setup, as Yamux relies on an underlying
+transport to be provided. However, there is a distinction between client
+and server side of the connection.
+
+## Opening a stream
+
+To open a stream, an initial data or window update frame is sent
+with a new StreamID. The SYN flag should be set to signal a new stream.
+
+The receiver must then reply with either a data or window update frame
+with the StreamID along with the ACK flag to accept the stream or with
+the RST flag to reject the stream.
+
+Because we are relying on the reliable stream underneath, a connection
+can begin sending data once the SYN flag is sent. The corresponding
+ACK does not need to be received. This is particularly well suited
+for an RPC system where a client wants to open a stream and immediately
+fire a request without waiting for the RTT of the ACK.
+
+This does introduce the possibility of a connection being rejected
+after data has been sent already. This is a slight semantic difference
from TCP, where the connection cannot be refused after it is opened.
+Clients should be prepared to handle this by checking for an error
+that indicates a RST was received.
+
+## Closing a stream
+
+To close a stream, either side sends a data or window update frame
+along with the FIN flag. This does a half-close indicating the sender
+will send no further data.
+
+Once both sides have closed the connection, the stream is closed.
+
+Alternatively, if an error occurs, the RST flag can be used to
+hard close a stream immediately.
+
+## Flow Control
+
Yamux initially starts each stream with a 256KB window size.
+There is no window size for the session.
+
+To prevent the streams from stalling, window update frames should be
+sent regularly. Yamux can be configured to provide a larger limit for
window sizes. Both sides assume the initial 256KB window, but can
+immediately send a window update as part of the SYN/ACK indicating a
+larger window.
+
+Both sides should track the number of bytes sent in Data frames
+only, as only they are tracked as part of the window size.
+
+## Session termination
+
+When a session is being terminated, the Go Away message should
+be sent. The Length should be set to one of the following to
+provide an error code:
+
+* 0x0 Normal termination
+* 0x1 Protocol error
+* 0x2 Internal error
--- /dev/null
+package yamux
+
+import (
+ "bytes"
+ "io"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
// streamState tracks the lifecycle of a single logical stream,
// a simplified TCP-like state machine (see Close/processFlags).
type streamState int

const (
	// streamInit: created locally, no SYN sent yet.
	streamInit streamState = iota
	// streamSYNSent: we opened the stream and sent SYN.
	streamSYNSent
	// streamSYNReceived: peer opened the stream; our ACK goes out
	// with the next frame (see sendFlags).
	streamSYNReceived
	// streamEstablished: handshake complete in both directions.
	streamEstablished
	// streamLocalClose: we sent FIN; inbound data may still drain.
	streamLocalClose
	// streamRemoteClose: peer sent FIN; we may still write.
	streamRemoteClose
	// streamClosed: both sides closed; stream fully shut down.
	streamClosed
	// streamReset: torn down via RST or connection reset.
	streamReset
)
+
// Stream is used to represent a logical stream
// within a session.
type Stream struct {
	// Flow control windows: sendWindow is how many bytes we may still
	// send (accessed atomically); recvWindow is how many the peer may
	// still send us (guarded by recvLock).
	recvWindow uint32
	sendWindow uint32

	// id is the wire StreamID; session is the owning multiplexer.
	id uint32
	session *Session

	// state transitions are guarded by stateLock.
	state streamState
	stateLock sync.Mutex

	// recvBuf holds inbound data not yet consumed by Read. Allocated
	// lazily in readData and released by Shrink; guarded by recvLock.
	recvBuf *bytes.Buffer
	recvLock sync.Mutex

	// Reusable header + single-slot error channel for control frames
	// (window updates, FIN); serialized by controlHdrLock.
	controlHdr header
	controlErr chan error
	controlHdrLock sync.Mutex

	// Reusable header + single-slot error channel for data frames;
	// serialized by sendLock (held across all of Write).
	sendHdr header
	sendErr chan error
	sendLock sync.Mutex

	// Capacity-1 channels giving coalesced, non-blocking wakeups for
	// goroutines parked in Read/write.
	recvNotifyCh chan struct{}
	sendNotifyCh chan struct{}

	readDeadline atomic.Value // time.Time
	writeDeadline atomic.Value // time.Time
}
+
+// newStream is used to construct a new stream within
+// a given session for an ID
+func newStream(session *Session, id uint32, state streamState) *Stream {
+ s := &Stream{
+ id: id,
+ session: session,
+ state: state,
+ controlHdr: header(make([]byte, headerSize)),
+ controlErr: make(chan error, 1),
+ sendHdr: header(make([]byte, headerSize)),
+ sendErr: make(chan error, 1),
+ recvWindow: initialStreamWindow,
+ sendWindow: initialStreamWindow,
+ recvNotifyCh: make(chan struct{}, 1),
+ sendNotifyCh: make(chan struct{}, 1),
+ }
+ s.readDeadline.Store(time.Time{})
+ s.writeDeadline.Store(time.Time{})
+ return s
+}
+
// Session returns the Session that owns this stream.
func (s *Stream) Session() *Session {
	return s.session
}
+
// StreamID returns the wire-level ID of this stream.
func (s *Stream) StreamID() uint32 {
	return s.id
}
+
// Read is used to read from the stream
// into b. It blocks until data is available, the stream is closed
// (io.EOF once the buffer is drained), reset (ErrConnectionReset),
// or the read deadline expires (ErrTimeout). The deferred notify
// wakes any other reader so it can re-check state.
func (s *Stream) Read(b []byte) (n int, err error) {
	defer asyncNotify(s.recvNotifyCh)
START:
	s.stateLock.Lock()
	switch s.state {
	case streamLocalClose:
		fallthrough
	case streamRemoteClose:
		fallthrough
	case streamClosed:
		// Closed states still permit draining buffered data; only
		// report EOF once the receive buffer is empty.
		s.recvLock.Lock()
		if s.recvBuf == nil || s.recvBuf.Len() == 0 {
			s.recvLock.Unlock()
			s.stateLock.Unlock()
			return 0, io.EOF
		}
		s.recvLock.Unlock()
	case streamReset:
		s.stateLock.Unlock()
		return 0, ErrConnectionReset
	}
	s.stateLock.Unlock()

	// If there is no data available, block
	s.recvLock.Lock()
	if s.recvBuf == nil || s.recvBuf.Len() == 0 {
		s.recvLock.Unlock()
		goto WAIT
	}

	// Read any bytes
	n, _ = s.recvBuf.Read(b)
	s.recvLock.Unlock()

	// Send a window update potentially
	err = s.sendWindowUpdate()
	return n, err

WAIT:
	// Park until readData (or a state change) signals recvNotifyCh,
	// or the deadline fires. A zero deadline leaves timeout nil, so
	// the select blocks on the notify channel alone.
	var timeout <-chan time.Time
	var timer *time.Timer
	readDeadline := s.readDeadline.Load().(time.Time)
	if !readDeadline.IsZero() {
		delay := readDeadline.Sub(time.Now())
		timer = time.NewTimer(delay)
		timeout = timer.C
	}
	select {
	case <-s.recvNotifyCh:
		if timer != nil {
			timer.Stop()
		}
		goto START
	case <-timeout:
		return 0, ErrTimeout
	}
}
+
+// Write is used to write to the stream
+func (s *Stream) Write(b []byte) (n int, err error) {
+ s.sendLock.Lock()
+ defer s.sendLock.Unlock()
+ total := 0
+ for total < len(b) {
+ n, err := s.write(b[total:])
+ total += n
+ if err != nil {
+ return total, err
+ }
+ }
+ return total, nil
+}
+
+// write is used to write to the stream, may return on
+// a short write.
+func (s *Stream) write(b []byte) (n int, err error) {
+ var flags uint16
+ var max uint32
+ var body io.Reader
+START:
+ s.stateLock.Lock()
+ switch s.state {
+ case streamLocalClose:
+ fallthrough
+ case streamClosed:
+ s.stateLock.Unlock()
+ return 0, ErrStreamClosed
+ case streamReset:
+ s.stateLock.Unlock()
+ return 0, ErrConnectionReset
+ }
+ s.stateLock.Unlock()
+
+ // If there is no data available, block
+ window := atomic.LoadUint32(&s.sendWindow)
+ if window == 0 {
+ goto WAIT
+ }
+
+ // Determine the flags if any
+ flags = s.sendFlags()
+
+ // Send up to our send window
+ max = min(window, uint32(len(b)))
+ body = bytes.NewReader(b[:max])
+
+ // Send the header
+ s.sendHdr.encode(typeData, flags, s.id, max)
+ if err = s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil {
+ return 0, err
+ }
+
+ // Reduce our send window
+ atomic.AddUint32(&s.sendWindow, ^uint32(max-1))
+
+ // Unlock
+ return int(max), err
+
+WAIT:
+ var timeout <-chan time.Time
+ writeDeadline := s.writeDeadline.Load().(time.Time)
+ if !writeDeadline.IsZero() {
+ delay := writeDeadline.Sub(time.Now())
+ timeout = time.After(delay)
+ }
+ select {
+ case <-s.sendNotifyCh:
+ goto START
+ case <-timeout:
+ return 0, ErrTimeout
+ }
+ return 0, nil
+}
+
+// sendFlags determines any flags that are appropriate
+// based on the current stream state
+func (s *Stream) sendFlags() uint16 {
+ s.stateLock.Lock()
+ defer s.stateLock.Unlock()
+ var flags uint16
+ switch s.state {
+ case streamInit:
+ flags |= flagSYN
+ s.state = streamSYNSent
+ case streamSYNReceived:
+ flags |= flagACK
+ s.state = streamEstablished
+ }
+ return flags
+}
+
// sendWindowUpdate potentially sends a window update enabling
// further writes to take place. Must be invoked with the lock.
//
// The delta granted is the free buffer space not already covered by
// the current recvWindow. Deltas under half the max window are
// suppressed to avoid chatty updates — unless SYN/ACK flags need to
// go out anyway.
func (s *Stream) sendWindowUpdate() error {
	s.controlHdrLock.Lock()
	defer s.controlHdrLock.Unlock()

	// Determine the delta update
	max := s.session.config.MaxStreamWindowSize
	var bufLen uint32
	s.recvLock.Lock()
	if s.recvBuf != nil {
		bufLen = uint32(s.recvBuf.Len())
	}
	delta := (max - bufLen) - s.recvWindow

	// Determine the flags if any
	flags := s.sendFlags()

	// Check if we can omit the update
	if delta < (max/2) && flags == 0 {
		s.recvLock.Unlock()
		return nil
	}

	// Update our window optimistically, before the frame is actually
	// written; recvLock is released so reads aren't blocked on I/O.
	s.recvWindow += delta
	s.recvLock.Unlock()

	// Send the header
	s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
	if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
		return err
	}
	return nil
}
+
+// sendClose is used to send a FIN
+func (s *Stream) sendClose() error {
+ s.controlHdrLock.Lock()
+ defer s.controlHdrLock.Unlock()
+
+ flags := s.sendFlags()
+ flags |= flagFIN
+ s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
+ if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
+ return err
+ }
+ return nil
+}
+
// Close is used to close the stream
// from the local side. On an open stream this sends a FIN
// (half-close); if the remote already half-closed, the stream moves
// to fully closed and is removed from the session.
func (s *Stream) Close() error {
	closeStream := false
	s.stateLock.Lock()
	switch s.state {
	// Opened means we need to signal a close
	case streamSYNSent:
		fallthrough
	case streamSYNReceived:
		fallthrough
	case streamEstablished:
		s.state = streamLocalClose
		goto SEND_CLOSE

	// A second Close after a local half-close is a no-op.
	case streamLocalClose:
	// Remote already half-closed: this Close completes the shutdown.
	case streamRemoteClose:
		s.state = streamClosed
		closeStream = true
		goto SEND_CLOSE

	// Already fully closed or reset: nothing to do.
	case streamClosed:
	case streamReset:
	default:
		panic("unhandled state")
	}
	s.stateLock.Unlock()
	return nil
SEND_CLOSE:
	// Send FIN outside the state lock, wake any blocked goroutines,
	// and deregister from the session if fully closed.
	s.stateLock.Unlock()
	s.sendClose()
	s.notifyWaiting()
	if closeStream {
		s.session.closeStream(s.id)
	}
	return nil
}
+
// forceClose is used for when the session is exiting
// entirely: the stream jumps straight to streamClosed (no FIN/RST is
// sent) and any blocked readers/writers are woken to observe it.
func (s *Stream) forceClose() {
	s.stateLock.Lock()
	s.state = streamClosed
	s.stateLock.Unlock()
	s.notifyWaiting()
}
+
// processFlags is used to update the state of the stream
// based on set flags, if any. Lock must be held
func (s *Stream) processFlags(flags uint16) error {
	// Close the stream without holding the state lock
	closeStream := false
	defer func() {
		if closeStream {
			s.session.closeStream(s.id)
		}
	}()

	s.stateLock.Lock()
	defer s.stateLock.Unlock()
	if flags&flagACK == flagACK {
		// Peer acknowledged our SYN; notify the owning session.
		if s.state == streamSYNSent {
			s.state = streamEstablished
		}
		s.session.establishStream(s.id)
	}
	if flags&flagFIN == flagFIN {
		// Peer half-closed. If we had already closed our side, the
		// stream is now fully closed and removed from the session.
		switch s.state {
		case streamSYNSent:
			fallthrough
		case streamSYNReceived:
			fallthrough
		case streamEstablished:
			s.state = streamRemoteClose
			s.notifyWaiting()
		case streamLocalClose:
			s.state = streamClosed
			closeStream = true
			s.notifyWaiting()
		default:
			s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
			return ErrUnexpectedFlag
		}
	}
	if flags&flagRST == flagRST {
		// Hard reset: tear down immediately regardless of state.
		s.state = streamReset
		closeStream = true
		s.notifyWaiting()
	}
	return nil
}
+
// notifyWaiting notifies all the waiting channels
// (non-blocking), waking any goroutine parked in Read or write so it
// re-examines the stream state.
func (s *Stream) notifyWaiting() {
	asyncNotify(s.recvNotifyCh)
	asyncNotify(s.sendNotifyCh)
}
+
// incrSendWindow updates the size of our send window
// in response to a window update frame from the peer, applying any
// handshake/close flags the frame carries, then wakes a writer that
// may be blocked on an empty window.
func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
	if err := s.processFlags(flags); err != nil {
		return err
	}

	// Increase window, unblock a sender
	atomic.AddUint32(&s.sendWindow, hdr.Length())
	asyncNotify(s.sendNotifyCh)
	return nil
}
+
+// readData is used to handle a data frame
+func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
+ if err := s.processFlags(flags); err != nil {
+ return err
+ }
+
+ // Check that our recv window is not exceeded
+ length := hdr.Length()
+ if length == 0 {
+ return nil
+ }
+
+ // Wrap in a limited reader
+ conn = &io.LimitedReader{R: conn, N: int64(length)}
+
+ // Copy into buffer
+ s.recvLock.Lock()
+
+ if length > s.recvWindow {
+ s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
+ return ErrRecvWindowExceeded
+ }
+
+ if s.recvBuf == nil {
+ // Allocate the receive buffer just-in-time to fit the full data frame.
+ // This way we can read in the whole packet without further allocations.
+ s.recvBuf = bytes.NewBuffer(make([]byte, 0, length))
+ }
+ if _, err := io.Copy(s.recvBuf, conn); err != nil {
+ s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
+ s.recvLock.Unlock()
+ return err
+ }
+
+ // Decrement the receive window
+ s.recvWindow -= length
+ s.recvLock.Unlock()
+
+ // Unblock any readers
+ asyncNotify(s.recvNotifyCh)
+ return nil
+}
+
+// SetDeadline sets the read and write deadlines
+func (s *Stream) SetDeadline(t time.Time) error {
+ if err := s.SetReadDeadline(t); err != nil {
+ return err
+ }
+ if err := s.SetWriteDeadline(t); err != nil {
+ return err
+ }
+ return nil
+}
+
// SetReadDeadline sets the deadline for future Read calls.
// A zero value means Reads never time out.
func (s *Stream) SetReadDeadline(t time.Time) error {
	s.readDeadline.Store(t)
	return nil
}
+
// SetWriteDeadline sets the deadline for future Write calls.
// A zero value means Writes never time out.
func (s *Stream) SetWriteDeadline(t time.Time) error {
	s.writeDeadline.Store(t)
	return nil
}
+
+// Shrink is used to compact the amount of buffers utilized
+// This is useful when using Yamux in a connection pool to reduce
+// the idle memory utilization.
+func (s *Stream) Shrink() {
+ s.recvLock.Lock()
+ if s.recvBuf != nil && s.recvBuf.Len() == 0 {
+ s.recvBuf = nil
+ }
+ s.recvLock.Unlock()
+}
--- /dev/null
+package yamux
+
+import (
+ "sync"
+ "time"
+)
+
var (
	// timerPool recycles timers to avoid allocating one per timed
	// wait. New timers are created with a far-future duration (~10^6
	// hours) and immediately stopped, so they never fire on their
	// own; callers presumably Reset them to the desired delay and
	// Stop them before returning them (usage not shown in this file).
	timerPool = &sync.Pool{
		New: func() interface{} {
			timer := time.NewTimer(time.Hour * 1e6)
			timer.Stop()
			return timer
		},
	}
)
+
// asyncSendErr is used to try an async send of an error
// without ever blocking: a nil channel is a no-op, and if the channel
// has no free capacity the error is silently dropped.
func asyncSendErr(ch chan error, err error) {
	if ch != nil {
		select {
		case ch <- err:
		default:
		}
	}
}
+
// asyncNotify is used to signal a waiting goroutine
// without blocking: if the (capacity-1) channel already holds a
// pending notification, this one is coalesced with it.
func asyncNotify(ch chan struct{}) {
	select {
	case ch <- struct{}{}:
	default:
	}
}
+
// min computes the minimum of two values
func min(a, b uint32) uint32 {
	if b < a {
		return b
	}
	return a
}
--- /dev/null
+package yamux
+
+import (
+ "testing"
+)
+
+func TestAsyncSendErr(t *testing.T) {
+ ch := make(chan error)
+ asyncSendErr(ch, ErrTimeout)
+ select {
+ case <-ch:
+ t.Fatalf("should not get")
+ default:
+ }
+
+ ch = make(chan error, 1)
+ asyncSendErr(ch, ErrTimeout)
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("should get")
+ }
+}
+
+func TestAsyncNotify(t *testing.T) {
+ ch := make(chan struct{})
+ asyncNotify(ch)
+ select {
+ case <-ch:
+ t.Fatalf("should not get")
+ default:
+ }
+
+ ch = make(chan struct{}, 1)
+ asyncNotify(ch)
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("should get")
+ }
+}
+
// TestMin spot-checks min with the arguments in both orders.
func TestMin(t *testing.T) {
	for _, tc := range [][3]uint32{{1, 2, 1}, {2, 1, 1}} {
		if got := min(tc[0], tc[1]); got != tc[2] {
			t.Fatalf("bad")
		}
	}
}