mirror of
https://github.com/dutchcoders/transfer.sh.git
synced 2024-11-27 14:40:18 +01:00
200 lines
4.6 KiB
Go
200 lines
4.6 KiB
Go
// Copyright 2017 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package pubsub
|
|
|
|
// This file provides a mock in-memory pubsub server for streaming pull testing.
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"sync"
|
|
"time"
|
|
|
|
"cloud.google.com/go/internal/testutil"
|
|
emptypb "github.com/golang/protobuf/ptypes/empty"
|
|
pb "google.golang.org/genproto/googleapis/pubsub/v1"
|
|
)
|
|
|
|
type mockServer struct {
|
|
srv *testutil.Server
|
|
|
|
pb.SubscriberServer
|
|
|
|
Addr string
|
|
|
|
mu sync.Mutex
|
|
Acked map[string]bool // acked message IDs
|
|
Deadlines map[string]int32 // deadlines by message ID
|
|
pullResponses []*pullResponse
|
|
ackErrs []error
|
|
modAckErrs []error
|
|
wg sync.WaitGroup
|
|
sub *pb.Subscription
|
|
}
|
|
|
|
type pullResponse struct {
|
|
msgs []*pb.ReceivedMessage
|
|
err error
|
|
}
|
|
|
|
func newMockServer(port int) (*mockServer, error) {
|
|
srv, err := testutil.NewServerWithPort(port)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
mock := &mockServer{
|
|
srv: srv,
|
|
Addr: srv.Addr,
|
|
Acked: map[string]bool{},
|
|
Deadlines: map[string]int32{},
|
|
sub: &pb.Subscription{
|
|
AckDeadlineSeconds: 10,
|
|
PushConfig: &pb.PushConfig{},
|
|
},
|
|
}
|
|
pb.RegisterSubscriberServer(srv.Gsrv, mock)
|
|
srv.Start()
|
|
return mock, nil
|
|
}
|
|
|
|
// Each call to addStreamingPullMessages results in one StreamingPullResponse.
|
|
func (s *mockServer) addStreamingPullMessages(msgs []*pb.ReceivedMessage) {
|
|
s.mu.Lock()
|
|
s.pullResponses = append(s.pullResponses, &pullResponse{msgs, nil})
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *mockServer) addStreamingPullError(err error) {
|
|
s.mu.Lock()
|
|
s.pullResponses = append(s.pullResponses, &pullResponse{nil, err})
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *mockServer) addAckResponse(err error) {
|
|
s.mu.Lock()
|
|
s.ackErrs = append(s.ackErrs, err)
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *mockServer) addModAckResponse(err error) {
|
|
s.mu.Lock()
|
|
s.modAckErrs = append(s.modAckErrs, err)
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *mockServer) wait() {
|
|
s.wg.Wait()
|
|
}
|
|
|
|
// StreamingPull implements pb.SubscriberServer for one bidirectional stream.
//
//   - A background goroutine reads client requests, marking every ack ID in
//     s.Acked and recording every modify-deadline ack ID in s.Deadlines.
//   - The handler replays the scripted pull responses in order. A scripted
//     error ends the stream with that error after a short delay; a scripted
//     io.EOF ends it cleanly.
//   - Once the script is exhausted, the handler waits for the client to end
//     the stream and returns the receive error.
//
// The client closing its send side (io.EOF from Recv) is always reported as a
// clean shutdown (nil), whether it is observed while idle or between sends.
func (s *mockServer) StreamingPull(stream pb.Subscriber_StreamingPullServer) error {
	s.wg.Add(1)
	defer s.wg.Done()

	// errc carries the first Recv error. The goroutine sends at most once and
	// then exits, and the buffer guarantees that send never blocks, even if
	// the handler has already returned.
	errc := make(chan error, 1)
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		for {
			req, err := stream.Recv()
			if err != nil {
				errc <- err
				return
			}
			s.mu.Lock()
			for _, id := range req.AckIds {
				s.Acked[id] = true
			}
			// ModifyDeadlineAckIds and ModifyDeadlineSeconds are treated as
			// parallel slices; a request whose seconds slice is shorter would
			// panic on this index.
			for i, id := range req.ModifyDeadlineAckIds {
				s.Deadlines[id] = req.ModifyDeadlineSeconds[i]
			}
			s.mu.Unlock()
		}
	}()

	// recvResult maps a Recv error to the handler's return value. io.EOF
	// means the client finished normally, so it is not an error.
	recvResult := func(err error) error {
		if err == io.EOF {
			return nil
		}
		return err
	}

	// Send responses.
	for {
		s.mu.Lock()
		if len(s.pullResponses) == 0 {
			s.mu.Unlock()
			// Nothing to send, so wait for the client to shut down the stream.
			return recvResult(<-errc) // a real error, or at least EOF
		}
		pr := s.pullResponses[0]
		s.pullResponses = s.pullResponses[1:]
		s.mu.Unlock()

		if pr.err != nil {
			// Add a slight delay to ensure the server receives any
			// messages en route from the client before shutting down the stream.
			// This reduces flakiness of tests involving retry.
			time.Sleep(200 * time.Millisecond)
			if pr.err == io.EOF {
				return nil
			}
			return pr.err
		}

		// Stop early if Recv has already failed or the client has finished.
		select {
		case err := <-errc:
			return recvResult(err)
		default:
		}

		res := &pb.StreamingPullResponse{ReceivedMessages: pr.msgs}
		if err := stream.Send(res); err != nil {
			return err
		}
	}
}
|
|
|
|
func (s *mockServer) Acknowledge(ctx context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
|
|
var err error
|
|
s.mu.Lock()
|
|
if len(s.ackErrs) > 0 {
|
|
err = s.ackErrs[0]
|
|
s.ackErrs = s.ackErrs[1:]
|
|
}
|
|
s.mu.Unlock()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, id := range req.AckIds {
|
|
s.Acked[id] = true
|
|
}
|
|
return &emptypb.Empty{}, nil
|
|
}
|
|
|
|
func (s *mockServer) ModifyAckDeadline(ctx context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
|
|
var err error
|
|
s.mu.Lock()
|
|
if len(s.modAckErrs) > 0 {
|
|
err = s.modAckErrs[0]
|
|
s.modAckErrs = s.modAckErrs[1:]
|
|
}
|
|
s.mu.Unlock()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, id := range req.AckIds {
|
|
s.Deadlines[id] = req.AckDeadlineSeconds
|
|
}
|
|
return &emptypb.Empty{}, nil
|
|
}
|
|
|
|
// GetSubscription implements pb.SubscriberServer. It ignores the request and
// always returns the mock's single subscription (the pointer stored in s.sub)
// with no error.
func (s *mockServer) GetSubscription(ctx context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
	subscription := s.sub
	return subscription, nil
}
|