transfer.sh/vendor/github.com/google/martian/body/body_modifier_test.go
2019-03-17 20:19:56 +01:00

325 lines
9.6 KiB
Go

// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package body
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net/http"
"strings"
"testing"
"github.com/google/martian/messageview"
"github.com/google/martian/parse"
"github.com/google/martian/proxyutil"
)
func TestBodyModifier(t *testing.T) {
mod := NewModifier([]byte("text"), "text/plain")
req, err := http.NewRequest("GET", "/", strings.NewReader(""))
if err != nil {
t.Fatalf("NewRequest(): got %v, want no error", err)
}
req.Header.Set("Content-Encoding", "gzip")
if err := mod.ModifyRequest(req); err != nil {
t.Fatalf("ModifyRequest(): got %v, want no error", err)
}
if got, want := req.Header.Get("Content-Type"), "text/plain"; got != want {
t.Errorf("req.Header.Get(%q): got %v, want %v", "Content-Type", got, want)
}
if got, want := req.ContentLength, int64(len([]byte("text"))); got != want {
t.Errorf("req.ContentLength: got %d, want %d", got, want)
}
if got, want := req.Header.Get("Content-Encoding"), ""; got != want {
t.Errorf("req.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
}
got, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
}
req.Body.Close()
if want := []byte("text"); !bytes.Equal(got, want) {
t.Errorf("res.Body: got %q, want %q", got, want)
}
res := proxyutil.NewResponse(200, nil, req)
res.Header.Set("Content-Encoding", "gzip")
if err := mod.ModifyResponse(res); err != nil {
t.Fatalf("ModifyResponse(): got %v, want no error", err)
}
if got, want := res.Header.Get("Content-Type"), "text/plain"; got != want {
t.Errorf("res.Header.Get(%q): got %v, want %v", "Content-Type", got, want)
}
if got, want := res.ContentLength, int64(len([]byte("text"))); got != want {
t.Errorf("res.ContentLength: got %d, want %d", got, want)
}
if got, want := res.Header.Get("Content-Encoding"), ""; got != want {
t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
}
got, err = ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
}
res.Body.Close()
if want := []byte("text"); !bytes.Equal(got, want) {
t.Errorf("res.Body: got %q, want %q", got, want)
}
}
func TestRangeHeaderRequestSingleRange(t *testing.T) {
mod := NewModifier([]byte("0123456789"), "text/plain")
req, err := http.NewRequest("GET", "/", strings.NewReader(""))
if err != nil {
t.Fatalf("NewRequest(): got %v, want no error", err)
}
req.Header.Set("Range", "bytes=1-4")
res := proxyutil.NewResponse(200, nil, req)
if err := mod.ModifyResponse(res); err != nil {
t.Fatalf("ModifyResponse(): got %v, want no error", err)
}
if got, want := res.StatusCode, http.StatusPartialContent; got != want {
t.Errorf("res.Status: got %v, want %v", got, want)
}
if got, want := res.ContentLength, int64(len([]byte("1234"))); got != want {
t.Errorf("res.ContentLength: got %d, want %d", got, want)
}
if got, want := res.Header.Get("Content-Range"), "bytes 1-4/10"; got != want {
t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
}
got, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
}
res.Body.Close()
if want := []byte("1234"); !bytes.Equal(got, want) {
t.Errorf("res.Body: got %q, want %q", got, want)
}
}
func TestRangeHeaderRequestSingleRangeHasAllTheBytes(t *testing.T) {
mod := NewModifier([]byte("0123456789"), "text/plain")
req, err := http.NewRequest("GET", "/", strings.NewReader(""))
if err != nil {
t.Fatalf("NewRequest(): got %v, want no error", err)
}
req.Header.Set("Range", "bytes=0-")
res := proxyutil.NewResponse(200, nil, req)
if err := mod.ModifyResponse(res); err != nil {
t.Fatalf("ModifyResponse(): got %v, want no error", err)
}
if got, want := res.StatusCode, http.StatusPartialContent; got != want {
t.Errorf("res.Status: got %v, want %v", got, want)
}
if got, want := res.ContentLength, int64(len([]byte("0123456789"))); got != want {
t.Errorf("res.ContentLength: got %d, want %d", got, want)
}
if got, want := res.Header.Get("Content-Range"), "bytes 0-9/10"; got != want {
t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
}
got, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
}
res.Body.Close()
if want := []byte("0123456789"); !bytes.Equal(got, want) {
t.Errorf("res.Body: got %q, want %q", got, want)
}
}
func TestRangeNoEndingIndexSpecified(t *testing.T) {
mod := NewModifier([]byte("0123456789"), "text/plain")
req, err := http.NewRequest("GET", "/", strings.NewReader(""))
if err != nil {
t.Fatalf("NewRequest(): got %v, want no error", err)
}
req.Header.Set("Range", "bytes=8-")
res := proxyutil.NewResponse(200, nil, req)
if err := mod.ModifyResponse(res); err != nil {
t.Fatalf("ModifyResponse(): got %v, want no error", err)
}
if got, want := res.StatusCode, http.StatusPartialContent; got != want {
t.Errorf("res.Status: got %v, want %v", got, want)
}
if got, want := res.ContentLength, int64(len([]byte("89"))); got != want {
t.Errorf("res.ContentLength: got %d, want %d", got, want)
}
if got, want := res.Header.Get("Content-Range"), "bytes 8-9/10"; got != want {
t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
}
}
func TestRangeHeaderMultipartRange(t *testing.T) {
mod := NewModifier([]byte("0123456789"), "text/plain")
bndry := "3d6b6a416f9b5"
mod.SetBoundary(bndry)
req, err := http.NewRequest("GET", "/", strings.NewReader(""))
if err != nil {
t.Fatalf("NewRequest(): got %v, want no error", err)
}
req.Header.Set("Range", "bytes=1-4, 7-9")
res := proxyutil.NewResponse(200, nil, req)
if err := mod.ModifyResponse(res); err != nil {
t.Fatalf("ModifyResponse(): got %v, want no error", err)
}
if got, want := res.StatusCode, http.StatusPartialContent; got != want {
t.Errorf("res.Status: got %v, want %v", got, want)
}
if got, want := res.Header.Get("Content-Type"), "multipart/byteranges; boundary=3d6b6a416f9b5"; got != want {
t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Type", got, want)
}
mv := messageview.New()
if err := mv.SnapshotResponse(res); err != nil {
t.Fatalf("mv.SnapshotResponse(res): got %v, want no error", err)
}
br, err := mv.BodyReader()
if err != nil {
t.Fatalf("mv.BodyReader(): got %v, want no error", err)
}
mpr := multipart.NewReader(br, bndry)
prt1, err := mpr.NextPart()
if err != nil {
t.Fatalf("mpr.NextPart(): got %v, want no error", err)
}
defer prt1.Close()
if got, want := prt1.Header.Get("Content-Type"), "text/plain"; got != want {
t.Errorf("prt1.Header.Get(%q): got %q, want %q", "Content-Type", got, want)
}
if got, want := prt1.Header.Get("Content-Range"), "bytes 1-4/10"; got != want {
t.Errorf("prt1.Header.Get(%q): got %q, want %q", "Content-Range", got, want)
}
prt1b, err := ioutil.ReadAll(prt1)
if err != nil {
t.Errorf("ioutil.Readall(prt1): got %v, want no error", err)
}
if got, want := string(prt1b), "1234"; got != want {
t.Errorf("prt1 body: got %s, want %s", got, want)
}
prt2, err := mpr.NextPart()
if err != nil {
t.Fatalf("mpr.NextPart(): got %v, want no error", err)
}
defer prt2.Close()
if got, want := prt2.Header.Get("Content-Type"), "text/plain"; got != want {
t.Errorf("prt2.Header.Get(%q): got %q, want %q", "Content-Type", got, want)
}
if got, want := prt2.Header.Get("Content-Range"), "bytes 7-9/10"; got != want {
t.Errorf("prt2.Header.Get(%q): got %q, want %q", "Content-Range", got, want)
}
prt2b, err := ioutil.ReadAll(prt2)
if err != io.ErrUnexpectedEOF && err != nil {
t.Errorf("ioutil.Readall(prt2): got %v, want no error", err)
}
if got, want := string(prt2b), "789"; got != want {
t.Errorf("prt2 body: got %s, want %s", got, want)
}
_, err = mpr.NextPart()
if err == nil {
t.Errorf("mpr.NextPart: want io.EOF, got no error")
}
if err != io.EOF {
t.Errorf("mpr.NextPart: want io.EOF, got %v", err)
}
}
func TestModifierFromJSON(t *testing.T) {
data := base64.StdEncoding.EncodeToString([]byte("data"))
msg := fmt.Sprintf(`{
"body.Modifier":{
"scope": ["response"],
"contentType": "text/plain",
"body": %q
}
}`, data)
r, err := parse.FromJSON([]byte(msg))
if err != nil {
t.Fatalf("parse.FromJSON(): got %v, want no error", err)
}
resmod := r.ResponseModifier()
if resmod == nil {
t.Fatalf("resmod: got nil, want not nil")
}
req, err := http.NewRequest("GET", "/", strings.NewReader(""))
if err != nil {
t.Fatalf("NewRequest(): got %v, want no error", err)
}
res := proxyutil.NewResponse(200, nil, req)
if err := resmod.ModifyResponse(res); err != nil {
t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err)
}
if got, want := res.Header.Get("Content-Type"), "text/plain"; got != want {
t.Errorf("res.Header.Get(%q): got %v, want %v", "Content-Type", got, want)
}
got, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
}
res.Body.Close()
if want := []byte("data"); !bytes.Equal(got, want) {
t.Errorf("res.Body: got %q, want %q", got, want)
}
}