diff --git a/lib/targets.go b/lib/targets.go index 9624aa2f..3da9f45a 100644 --- a/lib/targets.go +++ b/lib/targets.go @@ -68,18 +68,19 @@ func NewStaticTargeter(tgts ...Target) Targeter { // returns a NewStaticTargeter with them. // // body will be set as the Target's body if no body is provided. -// hdr will be merged with the each Target's headers. +// header will be merged with the each Target's headers. func NewEagerTargeter(src io.Reader, body []byte, header http.Header) (Targeter, error) { var ( - sc = NewLazyTargeter(src, body, header) tgts []Target tgt Target - err error ) + reader := newReader(src, body, header) for { - if err = sc(&tgt); err == ErrNoTargets { + err := reader.load(&tgt) + if err == ErrNoTargets { break - } else if err != nil { + } + if err != nil { return nil, err } tgts = append(tgts, tgt) @@ -94,79 +95,107 @@ func NewEagerTargeter(src io.Reader, body []byte, header http.Header) (Targeter, // provided io.Reader on every invocation. // // body will be set as the Target's body if no body is provided. -// hdr will be merged with the each Target's headers. -func NewLazyTargeter(src io.Reader, body []byte, hdr http.Header) Targeter { +// header will be merged with the each Target's headers. +func NewLazyTargeter(src io.Reader, body []byte, header http.Header) Targeter { var mu sync.Mutex - sc := peekingScanner{src: bufio.NewScanner(src)} + reader := newReader(src, body, header) return func(tgt *Target) (err error) { mu.Lock() defer mu.Unlock() + return reader.load(tgt) + } +} - if tgt == nil { - return ErrNilTarget - } +type reader struct { + sc *peekingScanner + defaultBody []byte + defaultHeader http.Header +} - var line string - for { - if !sc.Scan() { - return ErrNoTargets - } - line = strings.TrimSpace(sc.Text()) - if len(line) != 0 { - break - } - } +func newReader(src io.Reader, body []byte, header http.Header) *reader { + return &reader{ + sc: &peekingScanner{src: bufio.NewScanner(src)}, + defaultBody: body, + defaultHeader: header, + } +} - tgt.Body = body - tgt.Header = http.Header{} - for k, vs := range hdr { - tgt.Header[k] = vs - } +func (r *reader) load(tgt *Target) (err error) { + if tgt == nil { + return ErrNilTarget + } - tokens := strings.SplitN(line, " ", 2) - if len(tokens) < 2 { - return fmt.Errorf("bad target: %s", line) - } - if !startsWithHTTPMethod(line) { - return fmt.Errorf("bad method: %s", tokens[0]) - } - tgt.Method = tokens[0] - if _, err = url.ParseRequestURI(tokens[1]); err != nil { - return fmt.Errorf("bad URL: %s", tokens[1]) + var line string + for { + if !r.sc.Scan() { + return ErrNoTargets } - tgt.URL = tokens[1] - line = strings.TrimSpace(sc.Peek()) - if line == "" || startsWithHTTPMethod(line) { - return nil + line = strings.TrimSpace(r.sc.Text()) + if len(line) != 0 { + break } - for sc.Scan() { - if line = strings.TrimSpace(sc.Text()); line == "" { - break - } else if strings.HasPrefix(line, "@") { - if tgt.Body, err = ioutil.ReadFile(line[1:]); err != nil { - return fmt.Errorf("bad body: %s", err) + } + + tokens := strings.SplitN(line, " ", 2) + if len(tokens) < 2 { + return fmt.Errorf("bad target: %s", line) + } + if !startsWithHTTPMethod(line) { + return fmt.Errorf("bad method: %s", tokens[0]) + } + if _, err = url.ParseRequestURI(tokens[1]); err != nil { + return fmt.Errorf("bad URL: %s", tokens[1]) + } + tgt.Body = r.defaultBody + tgt.Header = http.Header{} + for k, vs := range r.defaultHeader { + tgt.Header[k] = vs + } + tgt.Method = tokens[0] + tgt.URL = tokens[1] + + line = strings.TrimSpace(r.sc.Peek()) + if line == "" || startsWithHTTPMethod(line) { + return nil + } + for r.sc.Scan() { + if line = strings.TrimSpace(r.sc.Text()); line == "" { + break + } else if strings.HasPrefix(line, "@") { + if strings.HasPrefix(line, "@<<") { + tag := line[3:] + buf := bytes.Buffer{} + for r.sc.Scan() { + line = r.sc.Text() + if line == tag { + break + } + buf.WriteString(line + "\n") } - break + tgt.Body = buf.Bytes() + } else if tgt.Body, err = ioutil.ReadFile(line[1:]); err != nil { + return fmt.Errorf("bad body: %s", err) } - tokens = strings.SplitN(line, ":", 2) - if len(tokens) < 2 { + break + } + tokens = strings.SplitN(line, ":", 2) + if len(tokens) < 2 { + return fmt.Errorf("bad header: %s", line) + } + for i := range tokens { + if tokens[i] = strings.TrimSpace(tokens[i]); tokens[i] == "" { return fmt.Errorf("bad header: %s", line) } - for i := range tokens { - if tokens[i] = strings.TrimSpace(tokens[i]); tokens[i] == "" { - return fmt.Errorf("bad header: %s", line) - } - } - // Add key/value directly to the http.Header (map[string][]string). - // http.Header.Add() canonicalizes keys but vegeta is used - // to test systems that require case-sensitive headers. - tgt.Header[tokens[0]] = append(tgt.Header[tokens[0]], tokens[1]) - } - if err = sc.Err(); err != nil { - return ErrNoTargets } - return nil + // Add key/value directly to the http.Header (map[string][]string). + // http.Header.Add() canonicalizes keys but vegeta is used + // to test systems that require case-sensitive headers. + tgt.Header[tokens[0]] = append(tgt.Header[tokens[0]], tokens[1]) + } + if err = r.sc.Err(); err != nil { + return ErrNoTargets } + return nil } var httpMethodChecker = regexp.MustCompile("^[A-Z]+\\s") diff --git a/lib/targets_test.go b/lib/targets_test.go index c645661f..3395d2fc 100644 --- a/lib/targets_test.go +++ b/lib/targets_test.go @@ -145,6 +145,11 @@ func TestNewLazyTargeter(t *testing.T) { @`, bodyf.Name(), ` + POST http://foobar.org/herebody + @<