diff --git a/.gitignore b/.gitignore index 8f13aa30e..f840f872d 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ recepcert /net .container-flag* .VERSION +.python-version kubectl /receptorctl/.nox /receptorctl/.VERSION diff --git a/.golangci.yml b/.golangci.yml index 3f798b887..09764a029 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -102,6 +102,7 @@ linters-settings: - "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/logging" - "github.com/AaronH88/quic-go" + - "github.com/stretchr/testify/assert" issues: # Dont commit the following line. diff --git a/go.mod b/go.mod index f8cc90f45..6e4aa9a75 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/fortytw2/leaktest v1.3.0 github.com/fsnotify/fsnotify v1.8.0 github.com/ghjm/cmdline v0.1.2 - github.com/golang-jwt/jwt/v4 v4.5.1 + github.com/golang-jwt/jwt/v4 v4.5.2 github.com/google/go-cmp v0.6.0 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/gorilla/websocket v1.5.3 @@ -32,7 +32,10 @@ require ( k8s.io/client-go v0.31.3 ) -require github.com/stretchr/testify v1.10.0 // indirect +require ( + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/stretchr/testify v1.10.0 +) require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect diff --git a/go.sum b/go.sum index 2604061e7..62235b751 100644 --- a/go.sum +++ b/go.sum @@ -58,8 +58,8 @@ github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZ github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo= -github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v4 v4.5.2 
h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= +github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= diff --git a/internal/version/version.go b/internal/version/version.go index 1ea9d9cb7..f2fab35fb 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -15,11 +15,7 @@ type cmdlineCfg struct{}  // Run runs the action. func (cfg cmdlineCfg) Run() error { - if Version == "" { - fmt.Printf("Version unknown\n") - } else { - fmt.Printf("%s\n", Version) - } + fmt.Printf("%s\n", validateVersion()) return nil } @@ -32,3 +28,13 @@ func init() { cmdline.RegisterConfigTypeForApp("receptor-version", "version", "Displays the Receptor version.", cmdlineCfg{}, cmdline.Exclusive) } + +// validateVersion returns the text to display for Version, falling back to +// "Version unknown" when Version is empty. +func validateVersion() string { + if Version == "" { + return "Version unknown" + } + + return Version +} diff --git a/internal/version/version_test.go b/internal/version/version_test.go new file mode 100644 index 000000000..77f0b3dcf --- /dev/null +++ b/internal/version/version_test.go @@ -0,0 +1,42 @@ +package version + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateVersion(t *testing.T) { + // Version is package-level state; restore it so later tests see the original value. + original := Version + t.Cleanup(func() { Version = original }) + + type test struct { + name string + input string + expected string + } + + for _, tt := range []*test{ + { + name: "version is empty string", + input: "", + expected: "Version unknown", + }, + { + name: "version is zero", + input: "0", + expected: "0", + }, + { + name: "version is one", + input: "1", + expected: "1", + }, + } { + t.Run(tt.name, func(t *testing.T) { + Version = tt.input + assert.Equal(t, tt.expected, validateVersion()) + }) + } +} diff --git a/pkg/netceptor/conn.go b/pkg/netceptor/conn.go index 
07438a740..e0fa40748 100644 --- a/pkg/netceptor/conn.go +++ b/pkg/netceptor/conn.go @@ -30,6 +30,14 @@ var MaxIdleTimeoutForQuicConnections = 30 * time.Second // Having this variablized allows the tests to set KeepAliveForQuicConnections = False so that things will properly fail. var KeepAliveForQuicConnections = true +type QuicStreamForConn interface { + quic.Stream +} + +type QuicConnectionForConn interface { + quic.Connection +} + type acceptResult struct { conn net.Conn err error @@ -304,13 +312,28 @@ func (li *Listener) Addr() net.Addr { type Conn struct { s *Netceptor pc PacketConner - qc quic.Connection - qs quic.Stream + qc QuicConnectionForConn + qs QuicStreamForConn doneChan chan struct{} doneOnce *sync.Once ctx context.Context } +// NewConn constructs a new Conn instance, so that the test package can create one. +func NewConn(s *Netceptor, pc PacketConner, qc QuicConnectionForConn, qs QuicStreamForConn, doneChan chan struct{}, doneOnce *sync.Once, ctx context.Context) *Conn { + conn := &Conn{ + s: s, + pc: pc, + qc: qc, + qs: qs, + doneChan: doneChan, + doneOnce: doneOnce, + ctx: ctx, + } + + return conn +} + // Dial returns a stream connection compatible with Go's net.Conn. 
func (s *Netceptor) Dial(node string, service string, tlscfg *tls.Config) (*Conn, error) { return s.DialContext(context.Background(), node, service, tlscfg) @@ -416,15 +439,7 @@ func (s *Netceptor) DialContext(ctx context.Context, node string, service string return } }() - conn := &Conn{ - s: s, - pc: pc, - qc: qc, - qs: qs, - doneChan: doneChan, - doneOnce: &sync.Once{}, - ctx: cctx, - } + conn := NewConn(s, pc, qc, qs, doneChan, &sync.Once{}, cctx) return conn, nil } diff --git a/pkg/netceptor/conn_test.go b/pkg/netceptor/conn_test.go new file mode 100644 index 000000000..f5a351626 --- /dev/null +++ b/pkg/netceptor/conn_test.go @@ -0,0 +1,234 @@ +package netceptor_test + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/ansible/receptor/pkg/netceptor" + "github.com/ansible/receptor/pkg/netceptor/mock_netceptor" + "github.com/quic-go/quic-go" + "go.uber.org/mock/gomock" +) + +type TestConn struct { + pc netceptor.PacketConner + qc netceptor.QuicConnectionForConn + qs netceptor.QuicStreamForConn +} + +func makeConn(t testing.TB, tc TestConn) *netceptor.Conn { + t.Helper() + conn := netceptor.NewConn( + netceptor.New(context.TODO(), "test-node"), // netceptor + tc.pc, // PacketConner + tc.qc, // Connection + tc.qs, // Stream + make(chan struct{}, 1), // doneChan + &sync.Once{}, // doneOnce + context.TODO(), // context + ) + + return conn +} + +// These tests operate on the quic Stream. 
+func TestRead(t *testing.T) { + ctrl := gomock.NewController(t) + buf := make([]byte, 1) + // Create a mock QuicStream + mockQs := mock_netceptor.NewMockQuicStreamForConn(ctrl) + // both success and error + t.Run("Returns number of bytes from successful Read", func(t *testing.T) { + want := 1 + mockQs.EXPECT().Read(gomock.Eq(buf)).Return(want, nil).Times(1) + conn := makeConn(t, TestConn{qs: mockQs}) + got, err := conn.Read(buf) + if err != nil { + t.Fatalf("Read returned unexpected error %v", err) + } + if got != want { + t.Errorf("Wanted %v, got %v", want, got) + } + }) + + t.Run("Returns error from unsuccessful Read", func(t *testing.T) { + wantErr := errors.New("Read error") + mockQs.EXPECT().Read(gomock.Eq(buf)).Return(0, wantErr).Times(1) + conn := makeConn(t, TestConn{qs: mockQs}) + _, gotErr := conn.Read(buf) + if gotErr == nil { + t.Errorf("Read did not return expected error") + } + if gotErr != wantErr { + t.Errorf("Wanted %v, got %v", wantErr, gotErr) + } + }) +} + +func TestCancelRead(t *testing.T) { + ctrl := gomock.NewController(t) + mockQs := mock_netceptor.NewMockQuicStreamForConn(ctrl) + mockQs.EXPECT().CancelRead(gomock.Eq(quic.StreamErrorCode(499))).Times(1) + conn := makeConn(t, TestConn{qs: mockQs}) + conn.CancelRead() +} + +func TestWrite(t *testing.T) { + ctrl := gomock.NewController(t) + mockQs := mock_netceptor.NewMockQuicStreamForConn(ctrl) + bytes := []byte{4, 8, 15, 16, 23, 42} + t.Run("Returns number of bytes written in successful Write", func(t *testing.T) { + want := 6 + mockQs.EXPECT().Write(gomock.Eq(bytes)).Return(want, nil).Times(1) + conn := makeConn(t, TestConn{qs: mockQs}) + got, err := conn.Write(bytes) + if err != nil { + t.Fatalf("Write returned unexpected error %v", err) + } + if got != want { + t.Errorf("Wanted %v, got %v", want, got) + } + }) + t.Run("Returns error from unsuccessful Write", func(t *testing.T) { + wantErr := errors.New("Write error") + mockQs.EXPECT().Write(gomock.Eq(bytes)).Return(0, wantErr).Times(1) + 
conn := makeConn(t, TestConn{qs: mockQs}) + _, gotErr := conn.Write(bytes) + if gotErr == nil { + t.Errorf("Write did not return expected error") + } + if gotErr != wantErr { + t.Errorf("Wanted %v, got %v", wantErr, gotErr) + } + }) +} + +func TestClose(t *testing.T) { + ctrl := gomock.NewController(t) + mockQs := mock_netceptor.NewMockQuicStreamForConn(ctrl) + mockQs.EXPECT().Close().Return(nil) + conn := makeConn(t, TestConn{qs: mockQs}) + err := conn.Close() // This calls the doneOnce and closes the doneChan + // would be nice to test that the doneChan is closed + if err != nil { + t.Fatalf("conn.Close returned error %v", err) + } +} + +// These tests operate on the quic Connection. +func TestCloseConnection(t *testing.T) { + ctrl := gomock.NewController(t) + // quic Connection should be closed + mockQc := mock_netceptor.NewMockQuicConnectionForConn(ctrl) + mockQc.EXPECT().CloseWithError(quic.ApplicationErrorCode(0), gomock.Eq("normal close")).Return(nil).Times(1) + + // PacketConner should be cancelled + mockPc := mock_netceptor.NewMockPacketConner(ctrl) + mockPc.EXPECT().Cancel().Times(1) + + // The CloseConnection method logs some information to the netceptor's Logger, so mock them + mockPc.EXPECT().LocalService().Return("test-local-service").Times(1) + mockQc.EXPECT().RemoteAddr().Return(netceptor.Addr{}) + + conn := makeConn(t, TestConn{pc: mockPc, qc: mockQc}) + err := conn.CloseConnection() + if err != nil { + t.Fatalf("conn.CloseConnection returned error %v", err) + } +} + +func TestLocalAddr(t *testing.T) { + want := netceptor.Addr{} // Could mock the net interacted here rather than an empty Addr{} + ctrl := gomock.NewController(t) + mockQc := mock_netceptor.NewMockQuicConnectionForConn(ctrl) + mockQc.EXPECT().LocalAddr().Return(want).Times(1) + conn := makeConn(t, TestConn{qc: mockQc}) + got := conn.LocalAddr() + if got != want { + t.Errorf("Wanted %v, got %v", want, got) + } +} + +func TestRemoteAddr(t *testing.T) { + want := netceptor.Addr{} // Could 
mock the net interacted here rather than an empty Addr{} + ctrl := gomock.NewController(t) + mockQc := mock_netceptor.NewMockQuicConnectionForConn(ctrl) + mockQc.EXPECT().RemoteAddr().Return(want).Times(1) + conn := makeConn(t, TestConn{qc: mockQc}) + got := conn.RemoteAddr() + if got != want { + t.Errorf("Wanted %v, got %v", want, got) + } +} + +func TestSetDeadline(t *testing.T) { + ctrl := gomock.NewController(t) + mockQs := mock_netceptor.NewMockQuicStreamForConn(ctrl) + want := time.Now().Add(10 * time.Second) + t.Run("Returns no error after successful SetDeadline", func(t *testing.T) { + mockQs.EXPECT().SetDeadline(gomock.Eq(want)).Return(nil) + conn := makeConn(t, TestConn{qs: mockQs}) + err := conn.SetDeadline(want) + if err != nil { + t.Fatalf("conn.SetDeadline returned error %v", err) + } + }) + t.Run("Returns error from unsuccessful SetDeadline", func(t *testing.T) { + wantErr := errors.New("SetDeadline error") + mockQs.EXPECT().SetDeadline(gomock.Eq(want)).Return(wantErr) + conn := makeConn(t, TestConn{qs: mockQs}) + gotErr := conn.SetDeadline(want) + if gotErr != wantErr { + t.Errorf("Wanted %v, got %v", wantErr, gotErr) + } + }) +} + +func TestSetReadDeadline(t *testing.T) { + ctrl := gomock.NewController(t) + mockQs := mock_netceptor.NewMockQuicStreamForConn(ctrl) + want := time.Now().Add(10 * time.Second) + t.Run("Returns no error after successful SetReadDeadline", func(t *testing.T) { + mockQs.EXPECT().SetReadDeadline(gomock.Eq(want)).Return(nil) + conn := makeConn(t, TestConn{qs: mockQs}) + err := conn.SetReadDeadline(want) + if err != nil { + t.Fatalf("conn.SetReadDeadline returned error %v", err) + } + }) + t.Run("Returns error from unsuccessful SetReadDeadline", func(t *testing.T) { + wantErr := errors.New("SetReadDeadline error") + mockQs.EXPECT().SetReadDeadline(gomock.Eq(want)).Return(wantErr) + conn := makeConn(t, TestConn{qs: mockQs}) + gotErr := conn.SetReadDeadline(want) + if gotErr != wantErr { + t.Errorf("Wanted %v, got %v", 
wantErr, gotErr) + } + }) +} + +func TestSetWriteDeadline(t *testing.T) { + ctrl := gomock.NewController(t) + mockQs := mock_netceptor.NewMockQuicStreamForConn(ctrl) + want := time.Now().Add(10 * time.Second) + t.Run("Returns no error after successful SetWriteDeadline", func(t *testing.T) { + mockQs.EXPECT().SetWriteDeadline(gomock.Eq(want)).Return(nil) + conn := makeConn(t, TestConn{qs: mockQs}) + err := conn.SetWriteDeadline(want) + if err != nil { + t.Fatalf("conn.SetWriteDeadline returned error %v", err) + } + }) + t.Run("Returns error from unsuccessful SetWriteDeadline", func(t *testing.T) { + wantErr := errors.New("SetWriteDeadline error") + mockQs.EXPECT().SetWriteDeadline(gomock.Eq(want)).Return(wantErr) + conn := makeConn(t, TestConn{qs: mockQs}) + gotErr := conn.SetWriteDeadline(want) + if gotErr != wantErr { + t.Errorf("Wanted %v, got %v", wantErr, gotErr) + } + }) +} diff --git a/pkg/netceptor/mock_netceptor/conn.go b/pkg/netceptor/mock_netceptor/conn.go new file mode 100644 index 000000000..093708648 --- /dev/null +++ b/pkg/netceptor/mock_netceptor/conn.go @@ -0,0 +1,395 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: pkg/netceptor/conn.go +// +// Generated by this command: +// +// mockgen -source=pkg/netceptor/conn.go -destination=pkg/netceptor/mock_netceptor/conn.go +// + +// Package mock_netceptor is a generated GoMock package. +package mock_netceptor + +import ( + context "context" + net "net" + reflect "reflect" + time "time" + + quic "github.com/quic-go/quic-go" + gomock "go.uber.org/mock/gomock" +) + +// MockQuicStreamForConn is a mock of QuicStreamForConn interface. +type MockQuicStreamForConn struct { + ctrl *gomock.Controller + recorder *MockQuicStreamForConnMockRecorder + isgomock struct{} +} + +// MockQuicStreamForConnMockRecorder is the mock recorder for MockQuicStreamForConn. +type MockQuicStreamForConnMockRecorder struct { + mock *MockQuicStreamForConn +} + +// NewMockQuicStreamForConn creates a new mock instance. 
+func NewMockQuicStreamForConn(ctrl *gomock.Controller) *MockQuicStreamForConn { + mock := &MockQuicStreamForConn{ctrl: ctrl} + mock.recorder = &MockQuicStreamForConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockQuicStreamForConn) EXPECT() *MockQuicStreamForConnMockRecorder { + return m.recorder +} + +// CancelRead mocks base method. +func (m *MockQuicStreamForConn) CancelRead(arg0 quic.StreamErrorCode) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CancelRead", arg0) +} + +// CancelRead indicates an expected call of CancelRead. +func (mr *MockQuicStreamForConnMockRecorder) CancelRead(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockQuicStreamForConn)(nil).CancelRead), arg0) +} + +// CancelWrite mocks base method. +func (m *MockQuicStreamForConn) CancelWrite(arg0 quic.StreamErrorCode) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CancelWrite", arg0) +} + +// CancelWrite indicates an expected call of CancelWrite. +func (mr *MockQuicStreamForConnMockRecorder) CancelWrite(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockQuicStreamForConn)(nil).CancelWrite), arg0) +} + +// Close mocks base method. +func (m *MockQuicStreamForConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockQuicStreamForConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockQuicStreamForConn)(nil).Close)) +} + +// Context mocks base method. 
+func (m *MockQuicStreamForConn) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockQuicStreamForConnMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockQuicStreamForConn)(nil).Context)) +} + +// Read mocks base method. +func (m *MockQuicStreamForConn) Read(p []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", p) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read. +func (mr *MockQuicStreamForConnMockRecorder) Read(p any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockQuicStreamForConn)(nil).Read), p) +} + +// SetDeadline mocks base method. +func (m *MockQuicStreamForConn) SetDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline. +func (mr *MockQuicStreamForConnMockRecorder) SetDeadline(t any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockQuicStreamForConn)(nil).SetDeadline), t) +} + +// SetReadDeadline mocks base method. +func (m *MockQuicStreamForConn) SetReadDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline. 
+func (mr *MockQuicStreamForConnMockRecorder) SetReadDeadline(t any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockQuicStreamForConn)(nil).SetReadDeadline), t) +} + +// SetWriteDeadline mocks base method. +func (m *MockQuicStreamForConn) SetWriteDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWriteDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline. +func (mr *MockQuicStreamForConnMockRecorder) SetWriteDeadline(t any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockQuicStreamForConn)(nil).SetWriteDeadline), t) +} + +// StreamID mocks base method. +func (m *MockQuicStreamForConn) StreamID() quic.StreamID { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StreamID") + ret0, _ := ret[0].(quic.StreamID) + return ret0 +} + +// StreamID indicates an expected call of StreamID. +func (mr *MockQuicStreamForConnMockRecorder) StreamID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockQuicStreamForConn)(nil).StreamID)) +} + +// Write mocks base method. +func (m *MockQuicStreamForConn) Write(p []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", p) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockQuicStreamForConnMockRecorder) Write(p any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockQuicStreamForConn)(nil).Write), p) +} + +// MockQuicConnectionForConn is a mock of QuicConnectionForConn interface. 
+type MockQuicConnectionForConn struct { + ctrl *gomock.Controller + recorder *MockQuicConnectionForConnMockRecorder + isgomock struct{} +} + +// MockQuicConnectionForConnMockRecorder is the mock recorder for MockQuicConnectionForConn. +type MockQuicConnectionForConnMockRecorder struct { + mock *MockQuicConnectionForConn +} + +// NewMockQuicConnectionForConn creates a new mock instance. +func NewMockQuicConnectionForConn(ctrl *gomock.Controller) *MockQuicConnectionForConn { + mock := &MockQuicConnectionForConn{ctrl: ctrl} + mock.recorder = &MockQuicConnectionForConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockQuicConnectionForConn) EXPECT() *MockQuicConnectionForConnMockRecorder { + return m.recorder +} + +// AcceptStream mocks base method. +func (m *MockQuicConnectionForConn) AcceptStream(arg0 context.Context) (quic.Stream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptStream", arg0) + ret0, _ := ret[0].(quic.Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcceptStream indicates an expected call of AcceptStream. +func (mr *MockQuicConnectionForConnMockRecorder) AcceptStream(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicConnectionForConn)(nil).AcceptStream), arg0) +} + +// AcceptUniStream mocks base method. +func (m *MockQuicConnectionForConn) AcceptUniStream(arg0 context.Context) (quic.ReceiveStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptUniStream", arg0) + ret0, _ := ret[0].(quic.ReceiveStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcceptUniStream indicates an expected call of AcceptUniStream. 
+func (mr *MockQuicConnectionForConnMockRecorder) AcceptUniStream(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicConnectionForConn)(nil).AcceptUniStream), arg0) +} + +// CloseWithError mocks base method. +func (m *MockQuicConnectionForConn) CloseWithError(arg0 quic.ApplicationErrorCode, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseWithError", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseWithError indicates an expected call of CloseWithError. +func (mr *MockQuicConnectionForConnMockRecorder) CloseWithError(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQuicConnectionForConn)(nil).CloseWithError), arg0, arg1) +} + +// ConnectionState mocks base method. +func (m *MockQuicConnectionForConn) ConnectionState() quic.ConnectionState { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ConnectionState") + ret0, _ := ret[0].(quic.ConnectionState) + return ret0 +} + +// ConnectionState indicates an expected call of ConnectionState. +func (mr *MockQuicConnectionForConnMockRecorder) ConnectionState() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockQuicConnectionForConn)(nil).ConnectionState)) +} + +// Context mocks base method. +func (m *MockQuicConnectionForConn) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockQuicConnectionForConnMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockQuicConnectionForConn)(nil).Context)) +} + +// LocalAddr mocks base method. 
+func (m *MockQuicConnectionForConn) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr. +func (mr *MockQuicConnectionForConnMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockQuicConnectionForConn)(nil).LocalAddr)) +} + +// OpenStream mocks base method. +func (m *MockQuicConnectionForConn) OpenStream() (quic.Stream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenStream") + ret0, _ := ret[0].(quic.Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenStream indicates an expected call of OpenStream. +func (mr *MockQuicConnectionForConnMockRecorder) OpenStream() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockQuicConnectionForConn)(nil).OpenStream)) +} + +// OpenStreamSync mocks base method. +func (m *MockQuicConnectionForConn) OpenStreamSync(arg0 context.Context) (quic.Stream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenStreamSync", arg0) + ret0, _ := ret[0].(quic.Stream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenStreamSync indicates an expected call of OpenStreamSync. +func (mr *MockQuicConnectionForConnMockRecorder) OpenStreamSync(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQuicConnectionForConn)(nil).OpenStreamSync), arg0) +} + +// OpenUniStream mocks base method. +func (m *MockQuicConnectionForConn) OpenUniStream() (quic.SendStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenUniStream") + ret0, _ := ret[0].(quic.SendStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenUniStream indicates an expected call of OpenUniStream. 
+func (mr *MockQuicConnectionForConnMockRecorder) OpenUniStream() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockQuicConnectionForConn)(nil).OpenUniStream)) +} + +// OpenUniStreamSync mocks base method. +func (m *MockQuicConnectionForConn) OpenUniStreamSync(arg0 context.Context) (quic.SendStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0) + ret0, _ := ret[0].(quic.SendStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenUniStreamSync indicates an expected call of OpenUniStreamSync. +func (mr *MockQuicConnectionForConnMockRecorder) OpenUniStreamSync(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicConnectionForConn)(nil).OpenUniStreamSync), arg0) +} + +// ReceiveDatagram mocks base method. +func (m *MockQuicConnectionForConn) ReceiveDatagram(arg0 context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReceiveDatagram", arg0) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReceiveDatagram indicates an expected call of ReceiveDatagram. +func (mr *MockQuicConnectionForConnMockRecorder) ReceiveDatagram(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveDatagram", reflect.TypeOf((*MockQuicConnectionForConn)(nil).ReceiveDatagram), arg0) +} + +// RemoteAddr mocks base method. +func (m *MockQuicConnectionForConn) RemoteAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr. 
+func (mr *MockQuicConnectionForConnMockRecorder) RemoteAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQuicConnectionForConn)(nil).RemoteAddr)) +} + +// SendDatagram mocks base method. +func (m *MockQuicConnectionForConn) SendDatagram(payload []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendDatagram", payload) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendDatagram indicates an expected call of SendDatagram. +func (mr *MockQuicConnectionForConnMockRecorder) SendDatagram(payload any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendDatagram", reflect.TypeOf((*MockQuicConnectionForConn)(nil).SendDatagram), payload) +} diff --git a/pkg/services/mock_services/tcp_proxy.go b/pkg/services/mock_services/tcp_proxy.go new file mode 100644 index 000000000..752c330c2 --- /dev/null +++ b/pkg/services/mock_services/tcp_proxy.go @@ -0,0 +1,427 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: tcp_proxy.go + +// Package mock_services is a generated GoMock package. +package mock_services + +import ( + tls "crypto/tls" + io "io" + net "net" + reflect "reflect" + time "time" + + logger "github.com/ansible/receptor/pkg/logger" + netceptor "github.com/ansible/receptor/pkg/netceptor" + gomock "go.uber.org/mock/gomock" +) + +// MockNetcForTCPProxy is a mock of NetcForTCPProxy interface. +type MockNetcForTCPProxy struct { + ctrl *gomock.Controller + recorder *MockNetcForTCPProxyMockRecorder +} + +// MockNetcForTCPProxyMockRecorder is the mock recorder for MockNetcForTCPProxy. +type MockNetcForTCPProxyMockRecorder struct { + mock *MockNetcForTCPProxy +} + +// NewMockNetcForTCPProxy creates a new mock instance. 
+func NewMockNetcForTCPProxy(ctrl *gomock.Controller) *MockNetcForTCPProxy { + mock := &MockNetcForTCPProxy{ctrl: ctrl} + mock.recorder = &MockNetcForTCPProxyMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNetcForTCPProxy) EXPECT() *MockNetcForTCPProxyMockRecorder { + return m.recorder +} + +// Dial mocks base method. +func (m *MockNetcForTCPProxy) Dial(node, service string, tlscfg *tls.Config) (*netceptor.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Dial", node, service, tlscfg) + ret0, _ := ret[0].(*netceptor.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Dial indicates an expected call of Dial. +func (mr *MockNetcForTCPProxyMockRecorder) Dial(node, service, tlscfg interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dial", reflect.TypeOf((*MockNetcForTCPProxy)(nil).Dial), node, service, tlscfg) +} + +// GetLogger mocks base method. +func (m *MockNetcForTCPProxy) GetLogger() *logger.ReceptorLogger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLogger") + ret0, _ := ret[0].(*logger.ReceptorLogger) + return ret0 +} + +// GetLogger indicates an expected call of GetLogger. +func (mr *MockNetcForTCPProxyMockRecorder) GetLogger() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogger", reflect.TypeOf((*MockNetcForTCPProxy)(nil).GetLogger)) +} + +// ListenAndAdvertise mocks base method. +func (m *MockNetcForTCPProxy) ListenAndAdvertise(service string, tlscfg *tls.Config, tags map[string]string) (*netceptor.Listener, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListenAndAdvertise", service, tlscfg, tags) + ret0, _ := ret[0].(*netceptor.Listener) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListenAndAdvertise indicates an expected call of ListenAndAdvertise. 
+func (mr *MockNetcForTCPProxyMockRecorder) ListenAndAdvertise(service, tlscfg, tags interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListenAndAdvertise", reflect.TypeOf((*MockNetcForTCPProxy)(nil).ListenAndAdvertise), service, tlscfg, tags) +} + +// MockNetLib is a mock of NetLib interface. +type MockNetLib struct { + ctrl *gomock.Controller + recorder *MockNetLibMockRecorder +} + +// MockNetLibMockRecorder is the mock recorder for MockNetLib. +type MockNetLibMockRecorder struct { + mock *MockNetLib +} + +// NewMockNetLib creates a new mock instance. +func NewMockNetLib(ctrl *gomock.Controller) *MockNetLib { + mock := &MockNetLib{ctrl: ctrl} + mock.recorder = &MockNetLibMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNetLib) EXPECT() *MockNetLibMockRecorder { + return m.recorder +} + +// Dial mocks base method. +func (m *MockNetLib) Dial(network, address string) (net.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Dial", network, address) + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Dial indicates an expected call of Dial. +func (mr *MockNetLibMockRecorder) Dial(network, address interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dial", reflect.TypeOf((*MockNetLib)(nil).Dial), network, address) +} + +// Listen mocks base method. +func (m *MockNetLib) Listen(network, address string) (net.Listener, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Listen", network, address) + ret0, _ := ret[0].(net.Listener) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Listen indicates an expected call of Listen. 
+func (mr *MockNetLibMockRecorder) Listen(network, address interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Listen", reflect.TypeOf((*MockNetLib)(nil).Listen), network, address) +} + +// MockTLSLib is a mock of TLSLib interface. +type MockTLSLib struct { + ctrl *gomock.Controller + recorder *MockTLSLibMockRecorder +} + +// MockTLSLibMockRecorder is the mock recorder for MockTLSLib. +type MockTLSLibMockRecorder struct { + mock *MockTLSLib +} + +// NewMockTLSLib creates a new mock instance. +func NewMockTLSLib(ctrl *gomock.Controller) *MockTLSLib { + mock := &MockTLSLib{ctrl: ctrl} + mock.recorder = &MockTLSLibMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTLSLib) EXPECT() *MockTLSLibMockRecorder { + return m.recorder +} + +// Dial mocks base method. +func (m *MockTLSLib) Dial(network, addr string, config *tls.Config) (*tls.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Dial", network, addr, config) + ret0, _ := ret[0].(*tls.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Dial indicates an expected call of Dial. +func (mr *MockTLSLibMockRecorder) Dial(network, addr, config interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dial", reflect.TypeOf((*MockTLSLib)(nil).Dial), network, addr, config) +} + +// NewListener mocks base method. +func (m *MockTLSLib) NewListener(inner net.Listener, config *tls.Config) net.Listener { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewListener", inner, config) + ret0, _ := ret[0].(net.Listener) + return ret0 +} + +// NewListener indicates an expected call of NewListener. 
+func (mr *MockTLSLibMockRecorder) NewListener(inner, config interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewListener", reflect.TypeOf((*MockTLSLib)(nil).NewListener), inner, config) +} + +// MockNetListenerTCP is a mock of NetListenerTCP interface. +type MockNetListenerTCP struct { + ctrl *gomock.Controller + recorder *MockNetListenerTCPMockRecorder +} + +// MockNetListenerTCPMockRecorder is the mock recorder for MockNetListenerTCP. +type MockNetListenerTCPMockRecorder struct { + mock *MockNetListenerTCP +} + +// NewMockNetListenerTCP creates a new mock instance. +func NewMockNetListenerTCP(ctrl *gomock.Controller) *MockNetListenerTCP { + mock := &MockNetListenerTCP{ctrl: ctrl} + mock.recorder = &MockNetListenerTCPMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNetListenerTCP) EXPECT() *MockNetListenerTCPMockRecorder { + return m.recorder +} + +// Accept mocks base method. +func (m *MockNetListenerTCP) Accept() (net.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Accept") + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Accept indicates an expected call of Accept. +func (mr *MockNetListenerTCPMockRecorder) Accept() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockNetListenerTCP)(nil).Accept)) +} + +// Addr mocks base method. +func (m *MockNetListenerTCP) Addr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Addr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// Addr indicates an expected call of Addr. +func (mr *MockNetListenerTCPMockRecorder) Addr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*MockNetListenerTCP)(nil).Addr)) +} + +// Close mocks base method. 
+func (m *MockNetListenerTCP) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockNetListenerTCPMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockNetListenerTCP)(nil).Close)) +} + +// MockUtilsLib is a mock of UtilsLib interface. +type MockUtilsLib struct { + ctrl *gomock.Controller + recorder *MockUtilsLibMockRecorder +} + +// MockUtilsLibMockRecorder is the mock recorder for MockUtilsLib. +type MockUtilsLibMockRecorder struct { + mock *MockUtilsLib +} + +// NewMockUtilsLib creates a new mock instance. +func NewMockUtilsLib(ctrl *gomock.Controller) *MockUtilsLib { + mock := &MockUtilsLib{ctrl: ctrl} + mock.recorder = &MockUtilsLibMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockUtilsLib) EXPECT() *MockUtilsLibMockRecorder { + return m.recorder +} + +// BridgeConns mocks base method. +func (m *MockUtilsLib) BridgeConns(c1 io.ReadWriteCloser, c1Name string, c2 io.ReadWriteCloser, c2Name string, logger *logger.ReceptorLogger) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "BridgeConns", c1, c1Name, c2, c2Name, logger) +} + +// BridgeConns indicates an expected call of BridgeConns. +func (mr *MockUtilsLibMockRecorder) BridgeConns(c1, c1Name, c2, c2Name, logger interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BridgeConns", reflect.TypeOf((*MockUtilsLib)(nil).BridgeConns), c1, c1Name, c2, c2Name, logger) +} + +// MockTCPConn is a mock of TCPConn interface. +type MockTCPConn struct { + ctrl *gomock.Controller + recorder *MockTCPConnMockRecorder +} + +// MockTCPConnMockRecorder is the mock recorder for MockTCPConn. 
+type MockTCPConnMockRecorder struct { + mock *MockTCPConn +} + +// NewMockTCPConn creates a new mock instance. +func NewMockTCPConn(ctrl *gomock.Controller) *MockTCPConn { + mock := &MockTCPConn{ctrl: ctrl} + mock.recorder = &MockTCPConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTCPConn) EXPECT() *MockTCPConnMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockTCPConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockTCPConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockTCPConn)(nil).Close)) +} + +// LocalAddr mocks base method. +func (m *MockTCPConn) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr. +func (mr *MockTCPConnMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockTCPConn)(nil).LocalAddr)) +} + +// Read mocks base method. +func (m *MockTCPConn) Read(b []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", b) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read. +func (mr *MockTCPConnMockRecorder) Read(b interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockTCPConn)(nil).Read), b) +} + +// RemoteAddr mocks base method. 
+func (m *MockTCPConn) RemoteAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr. +func (mr *MockTCPConnMockRecorder) RemoteAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockTCPConn)(nil).RemoteAddr)) +} + +// SetDeadline mocks base method. +func (m *MockTCPConn) SetDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline. +func (mr *MockTCPConnMockRecorder) SetDeadline(t interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockTCPConn)(nil).SetDeadline), t) +} + +// SetReadDeadline mocks base method. +func (m *MockTCPConn) SetReadDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline. +func (mr *MockTCPConnMockRecorder) SetReadDeadline(t interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockTCPConn)(nil).SetReadDeadline), t) +} + +// SetWriteDeadline mocks base method. +func (m *MockTCPConn) SetWriteDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWriteDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline. +func (mr *MockTCPConnMockRecorder) SetWriteDeadline(t interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockTCPConn)(nil).SetWriteDeadline), t) +} + +// Write mocks base method. 
+func (m *MockTCPConn) Write(b []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", b) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockTCPConnMockRecorder) Write(b interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockTCPConn)(nil).Write), b) +} diff --git a/pkg/services/mock_services/udp_proxy.go b/pkg/services/mock_services/udp_proxy.go index 218681370..98d5dde4e 100644 --- a/pkg/services/mock_services/udp_proxy.go +++ b/pkg/services/mock_services/udp_proxy.go @@ -1,10 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. // Source: pkg/services/udp_proxy.go -// -// Generated by this command: -// -// mockgen -source=pkg/services/udp_proxy.go -destination=pkg/services/mock_services/udp_proxy.go -// // Package mock_services is a generated GoMock package. package mock_services @@ -21,7 +16,6 @@ import ( type MockNetcForUDPProxy struct { ctrl *gomock.Controller recorder *MockNetcForUDPProxyMockRecorder - isgomock struct{} } // MockNetcForUDPProxyMockRecorder is the mock recorder for MockNetcForUDPProxy. @@ -65,7 +59,7 @@ func (m *MockNetcForUDPProxy) ListenPacket(service string) (netceptor.PacketConn } // ListenPacket indicates an expected call of ListenPacket. -func (mr *MockNetcForUDPProxyMockRecorder) ListenPacket(service any) *gomock.Call { +func (mr *MockNetcForUDPProxyMockRecorder) ListenPacket(service interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListenPacket", reflect.TypeOf((*MockNetcForUDPProxy)(nil).ListenPacket), service) } @@ -80,7 +74,7 @@ func (m *MockNetcForUDPProxy) ListenPacketAndAdvertise(service string, tags map[ } // ListenPacketAndAdvertise indicates an expected call of ListenPacketAndAdvertise. 
-func (mr *MockNetcForUDPProxyMockRecorder) ListenPacketAndAdvertise(service, tags any) *gomock.Call { +func (mr *MockNetcForUDPProxyMockRecorder) ListenPacketAndAdvertise(service, tags interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListenPacketAndAdvertise", reflect.TypeOf((*MockNetcForUDPProxy)(nil).ListenPacketAndAdvertise), service, tags) } @@ -94,7 +88,7 @@ func (m *MockNetcForUDPProxy) NewAddr(node, service string) netceptor.Addr { } // NewAddr indicates an expected call of NewAddr. -func (mr *MockNetcForUDPProxyMockRecorder) NewAddr(node, service any) *gomock.Call { +func (mr *MockNetcForUDPProxyMockRecorder) NewAddr(node, service interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewAddr", reflect.TypeOf((*MockNetcForUDPProxy)(nil).NewAddr), node, service) } diff --git a/pkg/services/tcp_proxy.go b/pkg/services/tcp_proxy.go index 5d73c3b7e..f1179e838 100644 --- a/pkg/services/tcp_proxy.go +++ b/pkg/services/tcp_proxy.go @@ -3,22 +3,85 @@ package services import ( "crypto/tls" "fmt" + "io" "net" "strconv" + "github.com/ansible/receptor/pkg/logger" "github.com/ansible/receptor/pkg/netceptor" "github.com/ansible/receptor/pkg/utils" "github.com/ghjm/cmdline" "github.com/spf13/viper" ) +//go:generate mockgen -package mock_services -source=tcp_proxy.go -destination=mock_services/tcp_proxy.go + +type NetcForTCPProxy interface { + GetLogger() *logger.ReceptorLogger + Dial(node string, service string, tlscfg *tls.Config) (*netceptor.Conn, error) + ListenAndAdvertise(service string, tlscfg *tls.Config, tags map[string]string) (*netceptor.Listener, error) +} + +// Interface for the net library to generate stubs with mockgen. 
+type NetLib interface { + Listen(network string, address string) (net.Listener, error) + Dial(network string, address string) (net.Conn, error) +} + +type NetTCPWrapper struct{} + +func (n *NetTCPWrapper) Listen(network string, address string) (net.Listener, error) { + return net.Listen(network, address) +} + +func (n *NetTCPWrapper) Dial(network string, address string) (net.Conn, error) { + return net.Dial(network, address) +} + +// Interface for the tls library to generate stubs with mockgen. +type TLSLib interface { + NewListener(inner net.Listener, config *tls.Config) net.Listener + Dial(network string, addr string, config *tls.Config) (*tls.Conn, error) +} + +type TLSTCPWrapper struct{} + +func (n *TLSTCPWrapper) NewListener(inner net.Listener, config *tls.Config) net.Listener { + return tls.NewListener(inner, config) +} + +func (n *TLSTCPWrapper) Dial(network string, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.Dial(network, addr, config) +} + +// Interface for the Net Listener to generate stubs with mockgen. +type NetListenerTCP interface { + net.Listener +} + +// Interface for the utils package to generate stubs with mockgen. +type UtilsLib interface { + BridgeConns(c1 io.ReadWriteCloser, c1Name string, c2 io.ReadWriteCloser, c2Name string, logger *logger.ReceptorLogger) +} + +type UtilsTCPWrapper struct{} + +func (u *UtilsTCPWrapper) BridgeConns(c1 io.ReadWriteCloser, c1Name string, c2 io.ReadWriteCloser, c2Name string, logger *logger.ReceptorLogger) { + utils.BridgeConns(c1, c1Name, c2, c2Name, logger) +} + +// Interface to mock the Connection object returned from Accept. +type TCPConn interface { + net.Conn +} + // TCPProxyServiceInbound listens on a TCP port and forwards the connection over the Receptor network. 
-func TCPProxyServiceInbound(s *netceptor.Netceptor, host string, port int, tlsServer *tls.Config, - node string, rservice string, tlsClient *tls.Config, +func TCPProxyServiceInbound(s NetcForTCPProxy, host string, port int, tlsServer *tls.Config, + node string, rservice string, tlsClient *tls.Config, netTCP NetLib, tlsTCP TLSLib, utilsTCP UtilsLib, ) error { - tli, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port))) + tli, err := netTCP.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port))) if tlsServer != nil { - tli = tls.NewListener(tli, tlsServer) + tli = tlsTCP.NewListener(tli, tlsServer) } if err != nil { return fmt.Errorf("error listening on TCP: %s", err) @@ -27,17 +90,17 @@ func TCPProxyServiceInbound(s *netceptor.Netceptor, host string, port int, tlsSe for { tc, err := tli.Accept() if err != nil { - s.Logger.Error("Error accepting TCP connection: %s\n", err) + s.GetLogger().Error("error accepting TCP connection: %s\n", err) return } qc, err := s.Dial(node, rservice, tlsClient) if err != nil { - s.Logger.Error("Error connecting on Receptor network: %s\n", err) + s.GetLogger().Error("error connecting on Receptor network: %s\n", err) continue } - go utils.BridgeConns(tc, "tcp service", qc, "receptor connection", s.Logger) + go utilsTCP.BridgeConns(tc, "tcp service", qc, "receptor connection", s.GetLogger()) } }() @@ -45,8 +108,8 @@ func TCPProxyServiceInbound(s *netceptor.Netceptor, host string, port int, tlsSe } // TCPProxyServiceOutbound listens on the Receptor network and forwards the connection via TCP. 
-func TCPProxyServiceOutbound(s *netceptor.Netceptor, service string, tlsServer *tls.Config, - address string, tlsClient *tls.Config, +func TCPProxyServiceOutbound(s NetcForTCPProxy, service string, tlsServer *tls.Config, + address string, tlsClient *tls.Config, netTCP NetLib, tlsTCP TLSLib, utilsTCP UtilsLib, ) error { qli, err := s.ListenAndAdvertise(service, tlsServer, map[string]string{ "type": "TCP Proxy", @@ -59,22 +122,22 @@ func TCPProxyServiceOutbound(s *netceptor.Netceptor, service string, tlsServer * for { qc, err := qli.Accept() if err != nil { - s.Logger.Error("Error accepting connection on Receptor network: %s\n", err) + s.GetLogger().Error("Error accepting connection on Receptor network: %s\n", err) return } var tc net.Conn if tlsClient == nil { - tc, err = net.Dial("tcp", address) + tc, err = netTCP.Dial("tcp", address) } else { - tc, err = tls.Dial("tcp", address, tlsClient) + tc, err = tlsTCP.Dial("tcp", address, tlsClient) } if err != nil { - s.Logger.Error("Error connecting via TCP: %s\n", err) + s.GetLogger().Error("Error connecting via TCP: %s\n", err) continue } - go utils.BridgeConns(qc, "receptor service", tc, "tcp connection", s.Logger) + go utilsTCP.BridgeConns(qc, "receptor service", tc, "tcp connection", s.GetLogger()) } }() @@ -104,7 +167,7 @@ func (cfg TCPProxyInboundCfg) Run() error { } return TCPProxyServiceInbound(netceptor.MainInstance, cfg.BindAddr, cfg.Port, TLSServerConfig, - cfg.RemoteNode, cfg.RemoteService, tlsClientCfg) + cfg.RemoteNode, cfg.RemoteService, tlsClientCfg, &NetTCPWrapper{}, &TLSTCPWrapper{}, &UtilsTCPWrapper{}) } // tcpProxyOutboundCfg is the cmdline configuration object for a TCP outbound proxy. 
@@ -131,7 +194,7 @@ func (cfg TCPProxyOutboundCfg) Run() error { return err } - return TCPProxyServiceOutbound(netceptor.MainInstance, cfg.Service, TLSServerConfig, cfg.Address, tlsClientCfg) + return TCPProxyServiceOutbound(netceptor.MainInstance, cfg.Service, TLSServerConfig, cfg.Address, tlsClientCfg, &NetTCPWrapper{}, &TLSTCPWrapper{}, &UtilsTCPWrapper{}) } func init() { diff --git a/pkg/services/tcp_proxy_test.go b/pkg/services/tcp_proxy_test.go new file mode 100644 index 000000000..f0d9f59c7 --- /dev/null +++ b/pkg/services/tcp_proxy_test.go @@ -0,0 +1,318 @@ +package services + +import ( + "context" + "crypto/tls" + "errors" + "testing" + + "github.com/ansible/receptor/pkg/logger" + "github.com/ansible/receptor/pkg/netceptor" + "github.com/ansible/receptor/pkg/services/mock_services" + "go.uber.org/mock/gomock" +) + +func setUpTCPMocks(ctrl *gomock.Controller) (*mock_services.MockNetcForTCPProxy, *mock_services.MockNetLib, *mock_services.MockTLSLib, *mock_services.MockNetListenerTCP, *mock_services.MockUtilsLib, *mock_services.MockTCPConn) { + mockNetceptor := mock_services.NewMockNetcForTCPProxy(ctrl) + mockNetLib := mock_services.NewMockNetLib(ctrl) + mockTLSLib := mock_services.NewMockTLSLib(ctrl) + mockNetListener := mock_services.NewMockNetListenerTCP(ctrl) + mockUtilsLib := mock_services.NewMockUtilsLib(ctrl) + mockTCPConn := mock_services.NewMockTCPConn(ctrl) + logger := logger.NewReceptorLogger("") + mockNetceptor.EXPECT().GetLogger().AnyTimes().Return(logger) + + return mockNetceptor, mockNetLib, mockTLSLib, mockNetListener, mockUtilsLib, mockTCPConn +} + +func TestTCPProxyServiceInbound(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var mockNetceptor *mock_services.MockNetcForTCPProxy + var mockNetLib *mock_services.MockNetLib + var mockTLSLib *mock_services.MockTLSLib + var mockNetListener *mock_services.MockNetListenerTCP + var mockUtilsLib *mock_services.MockUtilsLib + var mockTCPConn *mock_services.MockTCPConn + + 
type testCoverageItem struct { + name string + host string + port int + expectError bool + expectedErrorMessage string + node string + service string + tlsServerConfig *tls.Config + tlsClientConfig *tls.Config + calls func() + } + testCases := []testCoverageItem{ + { + name: "Fail to listen to input connections", + expectError: true, + tlsServerConfig: &tls.Config{}, + expectedErrorMessage: "error listening on TCP: failed to stablish a connection", + calls: func() { + mockNetLib.EXPECT().Listen(gomock.Any(), gomock.Any()).Return(nil, errors.New("failed to stablish a connection")).Times(1) + mockTLSLib.EXPECT().NewListener(gomock.Any(), gomock.Any()).Return(mockNetListener).Times(1) + }, + }, + { + name: "Fail to listen to input connections with tls config set", + expectError: true, + tlsServerConfig: nil, + expectedErrorMessage: "error listening on TCP: failed to stablish a connection", + calls: func() { + mockNetLib.EXPECT().Listen(gomock.Any(), gomock.Any()).Return(nil, errors.New("failed to stablish a connection")).Times(1) + }, + }, + { + name: "Fail to accept incoming connections to the listener", + tlsServerConfig: nil, + calls: func() { + mockNetLib.EXPECT().Listen(gomock.Any(), gomock.Any()).Return(mockNetListener, nil).Times(1) + mockNetListener.EXPECT().Accept().Return(nil, errors.New("failed to accept incoming connection")).AnyTimes() + }, + }, + { + name: "Fail to dial to the receptor network after accepting an inbound connection", + tlsServerConfig: nil, + calls: func() { + mockNetLib.EXPECT().Listen(gomock.Any(), gomock.Any()).Return(mockNetListener, nil).Times(1) + mockNetListener.EXPECT().Accept().Return(mockTCPConn, nil).AnyTimes() + mockNetceptor.EXPECT().Dial(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("failed to connect to Receptor network")).AnyTimes() + }, + }, + { + name: "Bridge connections after accepting inbound TCP connection", + tlsServerConfig: nil, + calls: func() { + mockNetLib.EXPECT().Listen(gomock.Any(), 
gomock.Any()).Return(mockNetListener, nil).Times(1) + mockNetListener.EXPECT().Accept().Return(mockTCPConn, nil).AnyTimes() + mockNetceptor.EXPECT().Dial(gomock.Any(), gomock.Any(), gomock.Any()).Return(&netceptor.Conn{}, nil).AnyTimes() + mockUtilsLib.EXPECT().BridgeConns(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockNetceptor, mockNetLib, mockTLSLib, mockNetListener, mockUtilsLib, mockTCPConn = setUpTCPMocks(ctrl) + tc.calls() + err := TCPProxyServiceInbound(mockNetceptor, tc.host, tc.port, tc.tlsServerConfig, tc.node, tc.service, tc.tlsClientConfig, mockNetLib, mockTLSLib, mockUtilsLib) + if tc.expectError { + if err == nil { + t.Errorf("TCPProxyServiceInbound failed to raise error") + } else if tc.expectedErrorMessage != err.Error() { + t.Errorf("TCPProxyServiceInbound didn't return the correct error message") + } + } else if err != nil { + t.Errorf("TCPProxyServiceInbound unexpected case error") + } + }) + } +} + +func TestTCPProxyServiceOutbound(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var mockNetceptor *mock_services.MockNetcForTCPProxy + var mockNetLib *mock_services.MockNetLib + var mockTLSLib *mock_services.MockTLSLib + var mockNetListener *mock_services.MockNetListenerTCP + var mockUtilsLib *mock_services.MockUtilsLib + + type testCoverageItem struct { + name string + expectError bool + expectedErrorMessage string + service string + address string + tlsClientConfig *tls.Config + calls func() + } + testCases := []testCoverageItem{ + { + name: "Fail to listen and advertise connection", + expectError: true, + expectedErrorMessage: "error listening on Receptor network: failed to stablish a connection", + calls: func() { + mockNetceptor.EXPECT().ListenAndAdvertise(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("failed to stablish a connection")).Times(1) + }, + }, + { + name: "Fail 
to accept input connections", + calls: func() { + mockNetceptor.EXPECT().ListenAndAdvertise(gomock.Any(), gomock.Any(), gomock.Any()).Return(&netceptor.Listener{}, nil).Times(1) + mockNetListener.EXPECT().Accept().Return(nil, errors.New("connection acceptance failed")).AnyTimes() + }, + }, + { + name: "Fail to dial through non-TLS TCP connection", + tlsClientConfig: nil, + calls: func() { + mockNetceptor.EXPECT().ListenAndAdvertise(gomock.Any(), gomock.Any(), gomock.Any()).Return(&netceptor.Listener{}, nil).Times(1) + mockNetListener.EXPECT().Accept().Return(&netceptor.Conn{}, nil).AnyTimes() + mockNetLib.EXPECT().Dial(gomock.Any(), gomock.Any()).Return(nil, errors.New("non-TLS TCP dial failed")).AnyTimes() + }, + }, + { + name: "Fail to dial through TLS TCP connection", + tlsClientConfig: &tls.Config{}, + calls: func() { + mockNetceptor.EXPECT().ListenAndAdvertise(gomock.Any(), gomock.Any(), gomock.Any()).Return(&netceptor.Listener{}, nil).Times(1) + mockNetListener.EXPECT().Accept().Return(&netceptor.Conn{}, nil).AnyTimes() + mockTLSLib.EXPECT().Dial(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("TLS TCP dial failed")).AnyTimes() + }, + }, + { + name: "Complete connection bridge after successful non-TLS connection", + calls: func() { + mockNetceptor.EXPECT().ListenAndAdvertise(gomock.Any(), gomock.Any(), gomock.Any()).Return(&netceptor.Listener{}, nil).Times(1) + mockNetListener.EXPECT().Accept().Return(&netceptor.Conn{}, nil).AnyTimes() + mockTLSLib.EXPECT().Dial(gomock.Any(), gomock.Any(), gomock.Any()).Return(&tls.Conn{}, nil).AnyTimes() + mockUtilsLib.EXPECT().BridgeConns(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockNetceptor, mockNetLib, mockTLSLib, mockNetListener, mockUtilsLib, _ = setUpTCPMocks(ctrl) + tc.calls() + err := TCPProxyServiceOutbound(mockNetceptor, tc.service, &tls.Config{}, tc.address, 
tc.tlsClientConfig, mockNetLib, mockTLSLib, mockUtilsLib) + if tc.expectError { + if err == nil { + t.Errorf("TCPProxyServiceOutbound case failed to raise error") + } else if tc.expectedErrorMessage != err.Error() { + t.Errorf("TCPProxyServiceOutbound didn't return the correct error message") + } + } else if err != nil { + t.Errorf("TCPProxyServiceOutbound unexpected case error") + } + }) + } +} + +func TestTCPProxyInboundCfgRun(t *testing.T) { + type testCoverageItem struct { + name string + expectError bool + expectedErrorMessage string + configObj TCPProxyInboundCfg + } + + testCases := []testCoverageItem{ + { + name: "Required parameters set no errors raised", + configObj: TCPProxyInboundCfg{ + Port: 8000, + RemoteNode: "", + RemoteService: "", + }, + }, + { + name: "Required parameters set wrong TLS Client Config", + expectError: true, + expectedErrorMessage: "unknown TLS config gibberish", + configObj: TCPProxyInboundCfg{ + Port: 8000, + RemoteNode: "", + RemoteService: "", + TLSClient: "gibberish", + }, + }, + { + name: "Required parameters set wrong TLS Server Config", + expectError: true, + expectedErrorMessage: "unknown TLS config gibberish", + configObj: TCPProxyInboundCfg{ + Port: 8000, + RemoteNode: "", + RemoteService: "", + TLSServer: "gibberish", + }, + }, + } + netceptor.MainInstance = netceptor.New(context.Background(), "test_tcp_proxy_inbound_cfg_run") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.configObj.Run() + if tc.expectError { + if err == nil { + t.Errorf("Test case failed to raise error") + } else if tc.expectedErrorMessage != err.Error() { + t.Errorf("Test expected error message: '%s', but got: '%s'", tc.expectedErrorMessage, err.Error()) + } + } else if err != nil { + t.Errorf("This test case wasn't expected to return an error: '%s'", err.Error()) + } + }) + } +} + +func TestTCPProxyOutboundCfgRun(t *testing.T) { + type testCoverageItem struct { + name string + expectError bool + 
expectedErrorMessage string + configObj TCPProxyOutboundCfg + } + + testCases := []testCoverageItem{ + { + name: "Required parameters set no errors raised", + configObj: TCPProxyOutboundCfg{ + Service: "", + Address: "0.0.0.0:8000", + }, + }, + { + name: "Required parameters set wrong TLS Server Config", + expectError: true, + expectedErrorMessage: "unknown TLS config gibberish", + configObj: TCPProxyOutboundCfg{ + Service: "", + Address: "0.0.0.0:8000", + TLSServer: "gibberish", + }, + }, + { + name: "Required parameters set missing port in Address", + expectError: true, + expectedErrorMessage: "address 0.0.0.0: missing port in address", + configObj: TCPProxyOutboundCfg{ + Service: "", + Address: "0.0.0.0", + }, + }, + { + name: "Required parameters set wrong TLS Client Config", + expectError: true, + expectedErrorMessage: "unknown TLS config gibberish", + configObj: TCPProxyOutboundCfg{ + Service: "", + Address: "0.0.0.0:8000", + TLSClient: "gibberish", + }, + }, + } + netceptor.MainInstance = netceptor.New(context.Background(), "test_tcp_proxy_outbound_cfg_run") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.configObj.Run() + if tc.expectError { + if err == nil { + t.Errorf("Test case failed to raise error") + } else if tc.expectedErrorMessage != err.Error() { + t.Errorf("Test expected error message: '%s', but got: '%s'", tc.expectedErrorMessage, err.Error()) + } + } else if err != nil { + t.Errorf("This test case wasn't expected to return an error: '%s'", err.Error()) + } + }) + } +} diff --git a/pkg/services/udp_proxy.go b/pkg/services/udp_proxy.go index a224e6926..c74ef8c12 100644 --- a/pkg/services/udp_proxy.go +++ b/pkg/services/udp_proxy.go @@ -19,17 +19,17 @@ type NetcForUDPProxy interface { ListenPacketAndAdvertise(service string, tags map[string]string) (netceptor.PacketConner, error) } -type Net struct{} +type NetUDPWrapper struct{} -func (n *Net) ResolveUDPAddr(network string, address string) (*net.UDPAddr, error) { 
+func (n *NetUDPWrapper) ResolveUDPAddr(network string, address string) (*net.UDPAddr, error) { return net.ResolveUDPAddr(network, address) } -func (n *Net) ListenUDP(network string, laddr *net.UDPAddr) (net_interface.UDPConnInterface, error) { +func (n *NetUDPWrapper) ListenUDP(network string, laddr *net.UDPAddr) (net_interface.UDPConnInterface, error) { return net.ListenUDP(network, laddr) } -func (n *Net) DialUDP(network string, laddr *net.UDPAddr, raddr *net.UDPAddr) (net_interface.UDPConnInterface, error) { +func (n *NetUDPWrapper) DialUDP(network string, laddr *net.UDPAddr, raddr *net.UDPAddr) (net_interface.UDPConnInterface, error) { return net.DialUDP(network, laddr, raddr) } @@ -205,7 +205,7 @@ type UDPProxyInboundCfg struct { func (cfg UDPProxyInboundCfg) Run() error { netceptor.MainInstance.Logger.Debug("Running UDP inbound proxy service %v\n", cfg) - return UDPProxyServiceInbound(netceptor.MainInstance, cfg.BindAddr, cfg.Port, cfg.RemoteNode, cfg.RemoteService, &Net{}) + return UDPProxyServiceInbound(netceptor.MainInstance, cfg.BindAddr, cfg.Port, cfg.RemoteNode, cfg.RemoteService, &NetUDPWrapper{}) } // udpProxyOutboundCfg is the cmdline configuration object for a UDP outbound proxy. @@ -218,7 +218,7 @@ type UDPProxyOutboundCfg struct { func (cfg UDPProxyOutboundCfg) Run() error { netceptor.MainInstance.Logger.Debug("Running UDP outbound proxy service %s\n", cfg) - return UDPProxyServiceOutbound(netceptor.MainInstance, cfg.Service, cfg.Address, &Net{}) + return UDPProxyServiceOutbound(netceptor.MainInstance, cfg.Service, cfg.Address, &NetUDPWrapper{}) } func init() { diff --git a/pkg/tickrunner/tickrunner.go b/pkg/tickrunner/tickrunner.go index 9a3a904af..bbb1b0516 100644 --- a/pkg/tickrunner/tickrunner.go +++ b/pkg/tickrunner/tickrunner.go @@ -8,7 +8,7 @@ import ( // Run runs a task at a given periodic interval, or as requested over a channel. // If many requests come in close to the same time, only run the task once. 
// Callers can ask for the task to be run within a given amount of time, which -// overrides defaultReqDelay. Sending a zero to the channel runs it immediately. +// overrides defaultReqDelay. Sending a zero to the channel runs it after defaultReqDelay. func Run(ctx context.Context, f func(), periodicInterval time.Duration, defaultReqDelay time.Duration) chan time.Duration { runChan := make(chan time.Duration) go func() { diff --git a/pkg/tickrunner/tickrunner_test.go b/pkg/tickrunner/tickrunner_test.go new file mode 100644 index 000000000..dab151112 --- /dev/null +++ b/pkg/tickrunner/tickrunner_test.go @@ -0,0 +1,121 @@ +package tickrunner + +import ( + "context" + "sync" + "testing" + "time" +) + +// SafeCounter is safe to use concurrently. +type SafeCounter struct { + mu sync.Mutex + v int +} + +// Inc increments the counter. +func (c *SafeCounter) Inc() { + c.mu.Lock() + // Lock so only one goroutine at a time can access c.v. + c.v++ + c.mu.Unlock() +} + +// Value returns the current value of the counter. +func (c *SafeCounter) Value() int { + c.mu.Lock() + // Lock so only one goroutine at a time can access c.v. + defer c.mu.Unlock() + + return c.v +} + +func TestRun(t *testing.T) { + type testCase struct { + name string + requestTime int + requestCount int + periodicInterval time.Duration + defaultReqDelay time.Duration + expectedRunCount int + /* This value must take into account the function's algorithm + and the other two time.Duration values */ + waitForBeforeChecks time.Duration + } + + tests := []testCase{ + { + name: "Run only with default periodicInterval once", + periodicInterval: time.Duration(2) * time.Second, + // Run can only execute one request with the set periodicInterval.
+ expectedRunCount: 1, + waitForBeforeChecks: time.Duration(3) * time.Second, + }, + { + name: "Run only with default periodicInterval more than once", + periodicInterval: time.Duration(2) * time.Second, + // Run executes once per periodicInterval, so it runs three times before the checks. + expectedRunCount: 3, + waitForBeforeChecks: time.Duration(7) * time.Second, + }, + { + name: "Run request inmediately", + requestCount: 1, + requestTime: 0, + /* Setting this to a high value so it doesn't run at all + with the default periodicInterval. We only want to test + for the requests sent. */ + periodicInterval: time.Duration(300) * time.Second, + defaultReqDelay: time.Duration(1) * time.Second, + // The single zero-delay request runs once, after defaultReqDelay. + expectedRunCount: 1, + waitForBeforeChecks: time.Duration(2) * time.Second, + }, + { + name: "Run sending some requests overrides default periodicInterval", + requestCount: 3, + requestTime: 2, + /* Setting this to a high value so it doesn't run at all + with the default periodicInterval. We only want to test + for the requests sent. */ + periodicInterval: time.Duration(300) * time.Second, + defaultReqDelay: time.Duration(1) * time.Second, + /* Due to the design of the test itself, the requests get sent + into the channel back to back with no time in between them. + This expectedRunCount is the correct value since only the + first one will be run. */ + expectedRunCount: 1, + waitForBeforeChecks: time.Duration(3) * time.Second, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx, ctxCancel := context.WithCancel(context.Background()) + defer ctxCancel() + + /* Create a counter value and increment function to keep track of + inner function calls count.
*/ runCounter := SafeCounter{} + runFunction := func() { runCounter.Inc() } + runChan := Run(ctx, runFunction, tc.periodicInterval, tc.defaultReqDelay) + + /* Send the required requests through the returned channel. + These are immediate back-to-back requests. + Following the function's logic, if there are too many requests coming in + simultaneously, these should be batched since only the oldest one is taken + into account before the set time passes. */ + for i := 0; i < tc.requestCount; i++ { + runChan <- time.Duration(tc.requestTime) * time.Second + } + + /* Since this function is time-based, let's wait for a calculated time + before asserting any value so we avoid race conditions with + non-blocking code. */ + time.Sleep(tc.waitForBeforeChecks) + if runCounter.Value() != tc.expectedRunCount { + t.Errorf("Run count: %d, Expected number of runs: %d", runCounter.Value(), tc.expectedRunCount) + } + }) + } +} diff --git a/pkg/workceptor/kubernetes.go b/pkg/workceptor/kubernetes.go index a1f9d2d82..514fa80a7 100644 --- a/pkg/workceptor/kubernetes.go +++ b/pkg/workceptor/kubernetes.go @@ -41,19 +41,20 @@ import ( // KubeUnit implements the WorkUnit interface. type KubeUnit struct { BaseWorkUnitForWorkUnit - authMethod string - streamMethod string - baseParams string - allowRuntimeAuth bool - allowRuntimeCommand bool - allowRuntimeParams bool - allowRuntimePod bool - deletePodOnRestart bool - namePrefix string - config *rest.Config - clientset *kubernetes.Clientset - pod *corev1.Pod - podPendingTimeout time.Duration + KubeAPIWrapperInstance KubeAPIer + authMethod string + streamMethod string + baseParams string + allowRuntimeAuth bool + allowRuntimeCommand bool + allowRuntimeParams bool + allowRuntimePod bool + deletePodOnRestart bool + namePrefix string + config *rest.Config + clientset *kubernetes.Clientset + Pod *corev1.Pod + podPendingTimeout time.Duration } // kubeExtraData is the content of the ExtraData JSON field for a Kubernetes worker.
@@ -167,12 +168,6 @@ func (ku KubeAPIWrapper) NewFakeAlwaysRateLimiter() flowcontrol.RateLimiter { return flowcontrol.NewFakeAlwaysRateLimiter() } -// KubeAPIWrapperInstance is a package level var that wraps all required kubernetes API calls. -// It is instantiated in the NewkubeWorker function and available throughout the package. -var KubeAPIWrapperInstance KubeAPIer - -var KubeAPIWrapperLock sync.Mutex - // ErrPodCompleted is returned when pod has already completed before we could attach. var ErrPodCompleted = fmt.Errorf("pod ran to completion") @@ -183,11 +178,11 @@ var ErrPodFailed = fmt.Errorf("pod failed to start") var ErrImagePullBackOff = fmt.Errorf("container failed to start") // podRunningAndReady is a completion criterion for pod ready to be attached to. -func podRunningAndReady() func(event watch.Event) (bool, error) { +func podRunningAndReady(kw KubeUnit) func(event watch.Event) (bool, error) { imagePullBackOffRetries := 3 inner := func(event watch.Event) (bool, error) { if event.Type == watch.Deleted { - return false, KubeAPIWrapperInstance.NewNotFound(schema.GroupResource{Resource: "pods"}, "") + return false, kw.KubeAPIWrapperInstance.NewNotFound(schema.GroupResource{Resource: "pods"}, "") } if t, ok := event.Object.(*corev1.Pod); ok { switch t.Status.Phase { @@ -251,8 +246,8 @@ func GetTimeoutOpenLogstream(kw *KubeUnit) int { func (kw *KubeUnit) kubeLoggingConnectionHandler(timestamps bool, sinceTime time.Time) (io.ReadCloser, error) { var logStream io.ReadCloser var err error - podNamespace := kw.pod.Namespace - podName := kw.pod.Name + podNamespace := kw.Pod.Namespace + podName := kw.Pod.Name podOptions := &corev1.PodLogOptions{ Container: "worker", Follow: true, @@ -262,9 +257,7 @@ func (kw *KubeUnit) kubeLoggingConnectionHandler(timestamps bool, sinceTime time podOptions.SinceTime = &metav1.Time{Time: sinceTime} } - KubeAPIWrapperLock.Lock() - logReq := KubeAPIWrapperInstance.GetLogs(kw.clientset, podNamespace, podName, podOptions) - 
KubeAPIWrapperLock.Unlock() + logReq := kw.KubeAPIWrapperInstance.GetLogs(kw.clientset, podNamespace, podName, podOptions) // get logstream, with retry for retries := 5; retries > 0; retries-- { logStream, err = logReq.Stream(kw.GetContext()) @@ -297,8 +290,8 @@ func (kw *KubeUnit) kubeLoggingNoReconnect(streamWait *sync.WaitGroup, stdout *S // known issues around this, as logstream can terminate due to log rotation // or 4 hr timeout defer streamWait.Done() - podNamespace := kw.pod.Namespace - podName := kw.pod.Name + podNamespace := kw.Pod.Namespace + podName := kw.Pod.Name logStream, err := kw.kubeLoggingConnectionHandler(false, time.Time{}) if err != nil { return @@ -320,8 +313,8 @@ func (kw *KubeUnit) KubeLoggingWithReconnect(streamWait *sync.WaitGroup, stdout defer streamWait.Done() var sinceTime time.Time var err error - podNamespace := kw.pod.Namespace - podName := kw.pod.Name + podNamespace := kw.Pod.Namespace + podName := kw.Pod.Name retries := 5 successfulWrite := false @@ -335,9 +328,7 @@ func (kw *KubeUnit) KubeLoggingWithReconnect(streamWait *sync.WaitGroup, stdout // get pod, with retry for retries := 5; retries > 0; retries-- { - KubeAPIWrapperLock.Lock() - kw.pod, err = KubeAPIWrapperInstance.Get(kw.GetContext(), kw.clientset, podNamespace, podName, metav1.GetOptions{}) - KubeAPIWrapperLock.Unlock() + kw.Pod, err = kw.KubeAPIWrapperInstance.Get(kw.GetContext(), kw.clientset, podNamespace, podName, metav1.GetOptions{}) if err == nil { break } @@ -379,6 +370,25 @@ func (kw *KubeUnit) KubeLoggingWithReconnect(streamWait *sync.WaitGroup, stdout return } + + if err == io.EOF { + if line != "" { + _, err = stdout.Write([]byte(line + "\n")) + if err != nil { + *stdoutErr = fmt.Errorf("writing final line to stdout: %s", err) + kw.GetWorkceptor().nc.GetLogger().Error("Error writing final line to stdout: %s", err) + + return + } + } + kw.GetWorkceptor().nc.GetLogger().Info("Detected EOF for pod %s/%s.", + podNamespace, + podName, + ) + + return + } + 
kw.GetWorkceptor().nc.GetLogger().Info( "Detected Error: %s for pod %s/%s. Will retry %d more times.", err, @@ -439,7 +449,6 @@ func (kw *KubeUnit) KubeLoggingWithReconnect(streamWait *sync.WaitGroup, stdout remainingRetries = retries // each time we read successfully, reset this counter successfulWrite = true } - logStream.Close() } } @@ -526,10 +535,8 @@ func (kw *KubeUnit) CreatePod(env map[string]string) error { pod.Spec.Containers[0].Env = evs } - KubeAPIWrapperLock.Lock() - // get pod and store to kw.pod - kw.pod, err = KubeAPIWrapperInstance.Create(kw.GetContext(), kw.clientset, ked.KubeNamespace, pod, metav1.CreateOptions{}) - KubeAPIWrapperLock.Unlock() + // get pod and store to kw.Pod + kw.Pod, err = kw.KubeAPIWrapperInstance.Create(kw.GetContext(), kw.clientset, ked.KubeNamespace, pod, metav1.CreateOptions{}) if err != nil { return err } @@ -544,23 +551,21 @@ func (kw *KubeUnit) CreatePod(env map[string]string) error { status.State = WorkStatePending status.Detail = "Pod created" status.StdoutSize = 0 - status.ExtraData.(*KubeExtraData).PodName = kw.pod.Name + status.ExtraData.(*KubeExtraData).PodName = kw.Pod.Name }) - KubeAPIWrapperLock.Lock() // Wait for the pod to be running - fieldSelector := KubeAPIWrapperInstance.OneTermEqualSelector("metadata.name", kw.pod.Name).String() - KubeAPIWrapperLock.Unlock() + fieldSelector := kw.KubeAPIWrapperInstance.OneTermEqualSelector("metadata.name", kw.Pod.Name).String() lw := &cache.ListWatch{ ListFunc: func(options metav1.ListOptions) (runtime.Object, error) { options.FieldSelector = fieldSelector - return KubeAPIWrapperInstance.List(kw.GetContext(), kw.clientset, ked.KubeNamespace, options) + return kw.KubeAPIWrapperInstance.List(kw.GetContext(), kw.clientset, ked.KubeNamespace, options) }, WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) { options.FieldSelector = fieldSelector - return KubeAPIWrapperInstance.Watch(kw.GetContext(), kw.clientset, ked.KubeNamespace, options) + return 
kw.KubeAPIWrapperInstance.Watch(kw.GetContext(), kw.clientset, ked.KubeNamespace, options) }, } @@ -572,22 +577,20 @@ func (kw *KubeUnit) CreatePod(env map[string]string) error { } time.Sleep(2 * time.Second) - KubeAPIWrapperLock.Lock() - ev, err := KubeAPIWrapperInstance.UntilWithSync(ctxPodReady, lw, &corev1.Pod{}, nil, podRunningAndReady()) - KubeAPIWrapperLock.Unlock() + ev, err := kw.KubeAPIWrapperInstance.UntilWithSync(ctxPodReady, lw, &corev1.Pod{}, nil, podRunningAndReady(*kw)) if ev == nil || ev.Object == nil { return fmt.Errorf("did not return an event while watching pod for work unit %s", kw.ID()) } var ok bool - kw.pod, ok = ev.Object.(*corev1.Pod) + kw.Pod, ok = ev.Object.(*corev1.Pod) if !ok { return fmt.Errorf("watch did not return a pod") } if err == ErrPodCompleted { // Hao: shouldn't we also call kw.Cancel() in these cases? - for _, cstat := range kw.pod.Status.ContainerStatuses { + for _, cstat := range kw.Pod.Status.ContainerStatuses { if cstat.Name == "worker" { if cstat.State.Terminated != nil && cstat.State.Terminated.ExitCode != 0 { return fmt.Errorf("container failed with exit code %d: %s", cstat.State.Terminated.ExitCode, cstat.State.Terminated.Message) @@ -613,12 +616,12 @@ func (kw *KubeUnit) CreatePod(env map[string]string) error { go kw.kubeLoggingNoReconnect(&streamWait, stdout, &stdoutErr) streamWait.Wait() kw.Cancel() - if len(kw.pod.Status.ContainerStatuses) == 1 { - if kw.pod.Status.ContainerStatuses[0].State.Waiting != nil { - return fmt.Errorf("%s, %s", err.Error(), kw.pod.Status.ContainerStatuses[0].State.Waiting.Reason) + if len(kw.Pod.Status.ContainerStatuses) == 1 { + if kw.Pod.Status.ContainerStatuses[0].State.Waiting != nil { + return fmt.Errorf("%s, %s", err.Error(), kw.Pod.Status.ContainerStatuses[0].State.Waiting.Reason) } - for _, cstat := range kw.pod.Status.ContainerStatuses { + for _, cstat := range kw.Pod.Status.ContainerStatuses { if cstat.Name == "worker" { if cstat.State.Waiting != nil { return fmt.Errorf("%s, 
%s", err.Error(), cstat.State.Waiting.Reason) @@ -664,8 +667,8 @@ func (kw *KubeUnit) runWorkUsingLogger() { skipStdin = false } - podName = kw.pod.Name - podNamespace = kw.pod.Namespace + podName = kw.Pod.Name + podNamespace = kw.Pod.Namespace } else { if podNamespace == "" { errMsg := fmt.Sprintf("Error creating pod: pod namespace is empty for pod %s", @@ -690,9 +693,7 @@ func (kw *KubeUnit) runWorkUsingLogger() { default: } - KubeAPIWrapperLock.Lock() - kw.pod, err = KubeAPIWrapperInstance.Get(kw.GetContext(), kw.clientset, podNamespace, podName, metav1.GetOptions{}) - KubeAPIWrapperLock.Unlock() + kw.Pod, err = kw.KubeAPIWrapperInstance.Get(kw.GetContext(), kw.clientset, podNamespace, podName, metav1.GetOptions{}) if err == nil { break } @@ -717,9 +718,7 @@ func (kw *KubeUnit) runWorkUsingLogger() { // Attach stdin stream to the pod var exec remotecommand.Executor if !skipStdin { - KubeAPIWrapperLock.Lock() - req := KubeAPIWrapperInstance.SubResource(kw.clientset, podName, podNamespace) - KubeAPIWrapperLock.Unlock() + req := kw.KubeAPIWrapperInstance.SubResource(kw.clientset, podName, podNamespace) req.VersionedParams( &corev1.PodExecOptions{ @@ -732,9 +731,7 @@ func (kw *KubeUnit) runWorkUsingLogger() { scheme.ParameterCodec, ) var err error - KubeAPIWrapperLock.Lock() - exec, err = KubeAPIWrapperInstance.NewSPDYExecutor(kw.config, "POST", req.URL()) - KubeAPIWrapperLock.Unlock() + exec, err = kw.KubeAPIWrapperInstance.NewSPDYExecutor(kw.config, "POST", req.URL()) if err != nil { errMsg := fmt.Sprintf("Error creating SPDY executor: %s", err) kw.UpdateBasicStatus(WorkStateFailed, errMsg, 0) @@ -828,12 +825,10 @@ func (kw *KubeUnit) runWorkUsingLogger() { var err error for retries := 5; retries > 0; retries-- { - KubeAPIWrapperLock.Lock() - err = KubeAPIWrapperInstance.StreamWithContext(kw.GetContext(), exec, remotecommand.StreamOptions{ + err = kw.KubeAPIWrapperInstance.StreamWithContext(kw.GetContext(), exec, remotecommand.StreamOptions{ Stdin: stdin, Tty: 
false, }) - KubeAPIWrapperLock.Unlock() if err != nil { // NOTE: io.EOF for stdin is handled by remotecommand and will not trigger this kw.GetWorkceptor().nc.GetLogger().Warning( @@ -868,7 +863,7 @@ func (kw *KubeUnit) runWorkUsingLogger() { // this is probably not possible... errMsg := fmt.Sprintf("Error reading stdin: %s", stdin.Error()) kw.GetWorkceptor().nc.GetLogger().Error(errMsg) //nolint:govet - kw.GetWorkceptor().nc.GetLogger().Error("Pod status at time of error %s", kw.pod.Status.String()) + kw.GetWorkceptor().nc.GetLogger().Error("Pod status at time of error %s", kw.Pod.Status.String()) kw.UpdateBasicStatus(WorkStateFailed, errMsg, stdout.Size()) close(stdinErrChan) // signal STDOUT goroutine to stop @@ -1208,12 +1203,8 @@ func (kw *KubeUnit) connectUsingKubeconfig() error { var err error ked := kw.UnredactedStatus().ExtraData.(*KubeExtraData) if ked.KubeConfig == "" { - KubeAPIWrapperLock.Lock() - clr := KubeAPIWrapperInstance.NewDefaultClientConfigLoadingRules() - KubeAPIWrapperLock.Unlock() - KubeAPIWrapperLock.Lock() - kw.config, err = KubeAPIWrapperInstance.BuildConfigFromFlags("", clr.GetDefaultFilename()) - KubeAPIWrapperLock.Unlock() + clr := kw.KubeAPIWrapperInstance.NewDefaultClientConfigLoadingRules() + kw.config, err = kw.KubeAPIWrapperInstance.BuildConfigFromFlags("", clr.GetDefaultFilename()) if ked.KubeNamespace == "" { c, err := clr.Load() if err != nil { @@ -1229,9 +1220,7 @@ func (kw *KubeUnit) connectUsingKubeconfig() error { } } } else { - KubeAPIWrapperLock.Lock() - cfg, err := KubeAPIWrapperInstance.NewClientConfigFromBytes([]byte(ked.KubeConfig)) - KubeAPIWrapperLock.Unlock() + cfg, err := kw.KubeAPIWrapperInstance.NewClientConfigFromBytes([]byte(ked.KubeConfig)) if err != nil { return err } @@ -1258,9 +1247,7 @@ func (kw *KubeUnit) connectUsingKubeconfig() error { func (kw *KubeUnit) connectUsingIncluster() error { var err error - KubeAPIWrapperLock.Lock() - kw.config, err = KubeAPIWrapperInstance.InClusterConfig() - 
KubeAPIWrapperLock.Unlock() + kw.config, err = kw.KubeAPIWrapperInstance.InClusterConfig() if err != nil { return err } @@ -1352,22 +1339,16 @@ func (kw *KubeUnit) connectToKube() error { if ok { switch envRateLimiter { case "never": - KubeAPIWrapperLock.Lock() - kw.config.RateLimiter = KubeAPIWrapperInstance.NewFakeNeverRateLimiter() - KubeAPIWrapperLock.Unlock() + kw.config.RateLimiter = kw.KubeAPIWrapperInstance.NewFakeNeverRateLimiter() case "always": - KubeAPIWrapperLock.Lock() - kw.config.RateLimiter = KubeAPIWrapperInstance.NewFakeAlwaysRateLimiter() - KubeAPIWrapperLock.Unlock() + kw.config.RateLimiter = kw.KubeAPIWrapperInstance.NewFakeAlwaysRateLimiter() default: } kw.GetWorkceptor().nc.GetLogger().Debug("RateLimiter: %s", envRateLimiter) } kw.GetWorkceptor().nc.GetLogger().Debug("QPS: %f, Burst: %d", kw.config.QPS, kw.config.Burst) - KubeAPIWrapperLock.Lock() - kw.clientset, err = KubeAPIWrapperInstance.NewForConfig(kw.config) - KubeAPIWrapperLock.Unlock() + kw.clientset, err = kw.KubeAPIWrapperInstance.NewForConfig(kw.config) if err != nil { return err } @@ -1536,9 +1517,7 @@ func (kw *KubeUnit) Restart() error { if err != nil { kw.GetWorkceptor().nc.GetLogger().Warning("Pod %s could not be deleted: %s", ked.PodName, err.Error()) } else { - KubeAPIWrapperLock.Lock() - err := KubeAPIWrapperInstance.Delete(context.Background(), kw.clientset, ked.KubeNamespace, ked.PodName, metav1.DeleteOptions{}) - KubeAPIWrapperLock.Unlock() + err := kw.KubeAPIWrapperInstance.Delete(context.Background(), kw.clientset, ked.KubeNamespace, ked.PodName, metav1.DeleteOptions{}) if err != nil { kw.GetWorkceptor().nc.GetLogger().Warning("Pod %s could not be deleted: %s", ked.PodName, err.Error()) } @@ -1562,12 +1541,10 @@ func (kw *KubeUnit) Start() error { func (kw *KubeUnit) Cancel() error { kw.CancelContext() kw.UpdateBasicStatus(WorkStateCanceled, "Canceled", -1) - if kw.pod != nil { - KubeAPIWrapperLock.Lock() - err := KubeAPIWrapperInstance.Delete(context.Background(), 
kw.clientset, kw.pod.Namespace, kw.pod.Name, metav1.DeleteOptions{}) - KubeAPIWrapperLock.Unlock() + if kw.Pod != nil { + err := kw.KubeAPIWrapperInstance.Delete(context.Background(), kw.clientset, kw.Pod.Namespace, kw.Pod.Name, metav1.DeleteOptions{}) if err != nil { - kw.GetWorkceptor().nc.GetLogger().Error("Error deleting pod %s: %s", kw.pod.Name, err) + kw.GetWorkceptor().nc.GetLogger().Error("Error deleting pod %s: %s", kw.Pod.Name, err) } } if kw.GetCancel() != nil { @@ -1616,6 +1593,7 @@ func (cfg KubeWorkerCfg) NewWorker(bwu BaseWorkUnitForWorkUnit, w *Workceptor, u } func (cfg KubeWorkerCfg) NewkubeWorker(bwu BaseWorkUnitForWorkUnit, w *Workceptor, unitID string, workType string, kawi KubeAPIer) WorkUnit { + var kubeAPIWrapperInstance KubeAPIer if bwu == nil { bwu = &BaseWorkUnit{ status: StatusFileData{ @@ -1630,16 +1608,15 @@ func (cfg KubeWorkerCfg) NewkubeWorker(bwu BaseWorkUnitForWorkUnit, w *Workcepto } } - KubeAPIWrapperLock.Lock() if kawi != nil { - KubeAPIWrapperInstance = kawi + kubeAPIWrapperInstance = kawi } else { - KubeAPIWrapperInstance = KubeAPIWrapper{} + kubeAPIWrapperInstance = KubeAPIWrapper{} } - KubeAPIWrapperLock.Unlock() ku := &KubeUnit{ BaseWorkUnitForWorkUnit: bwu, + KubeAPIWrapperInstance: kubeAPIWrapperInstance, authMethod: strings.ToLower(cfg.AuthMethod), streamMethod: strings.ToLower(cfg.StreamMethod), baseParams: cfg.Params, diff --git a/pkg/workceptor/kubernetes_test.go b/pkg/workceptor/kubernetes_test.go index 1c66371a3..4ce071e4e 100644 --- a/pkg/workceptor/kubernetes_test.go +++ b/pkg/workceptor/kubernetes_test.go @@ -487,11 +487,16 @@ func Test_IsCompatibleK8S(t *testing.T) { func TestKubeLoggingWithReconnect(t *testing.T) { var stdinErr error var stdoutErr error - ku, mockBaseWorkUnit, mockNetceptor, w, mockKubeAPI, ctrl, ctx := createKubernetesTestSetup(t) + _, mockBaseWorkUnit, mockNetceptor, w, mockKubeAPI, ctrl, ctx := createKubernetesTestSetup(t) + + pod := &corev1.Pod{TypeMeta: metav1.TypeMeta{}, ObjectMeta: 
metav1.ObjectMeta{Name: "Test_Name", Namespace: "Test_Namespace"}, Spec: corev1.PodSpec{}, Status: corev1.PodStatus{Phase: corev1.PodRunning}} kw := &workceptor.KubeUnit{ BaseWorkUnitForWorkUnit: mockBaseWorkUnit, + KubeAPIWrapperInstance: mockKubeAPI, + Pod: pod, } + tests := []struct { name string expectedCalls func() @@ -499,29 +504,11 @@ func TestKubeLoggingWithReconnect(t *testing.T) { { name: "Kube error should be read", expectedCalls: func() { - mockBaseWorkUnit.EXPECT().UpdateBasicStatus(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - config := rest.Config{} - mockKubeAPI.EXPECT().InClusterConfig().Return(&config, nil) - mockBaseWorkUnit.EXPECT().GetWorkceptor().Return(w).AnyTimes() - clientset := kubernetes.Clientset{} - mockKubeAPI.EXPECT().NewForConfig(gomock.Any()).Return(&clientset, nil) - lock := &sync.RWMutex{} - mockBaseWorkUnit.EXPECT().GetStatusLock().Return(lock).AnyTimes() - mockBaseWorkUnit.EXPECT().MonitorLocalStatus().AnyTimes() - mockBaseWorkUnit.EXPECT().UnitDir().Return("TestDir2").AnyTimes() - kubeExtraData := workceptor.KubeExtraData{} - status := workceptor.StatusFileData{ExtraData: &kubeExtraData} - mockBaseWorkUnit.EXPECT().GetStatusWithoutExtraData().Return(&status).AnyTimes() - mockBaseWorkUnit.EXPECT().GetStatusCopy().Return(status).AnyTimes() - mockBaseWorkUnit.EXPECT().GetContext().Return(ctx).AnyTimes() - pod := corev1.Pod{TypeMeta: metav1.TypeMeta{}, ObjectMeta: metav1.ObjectMeta{Name: "Test_Name"}, Spec: corev1.PodSpec{}, Status: corev1.PodStatus{}} - mockKubeAPI.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&pod, nil).AnyTimes() - mockBaseWorkUnit.EXPECT().UpdateFullStatus(gomock.Any()).AnyTimes() - field := hasTerm{} - mockKubeAPI.EXPECT().OneTermEqualSelector(gomock.Any(), gomock.Any()).Return(&field).AnyTimes() - ev := watch.Event{Object: &pod} - mockKubeAPI.EXPECT().UntilWithSync(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&ev, 
nil).AnyTimes() - mockKubeAPI.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&pod, nil).AnyTimes() + mockBaseWorkUnit.EXPECT().GetWorkceptor().Return(w) + mockBaseWorkUnit.EXPECT().GetContext().Return(ctx).Times(3) + mockKubeAPI.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(pod, nil) + logger := logger.NewReceptorLogger("") + mockNetceptor.EXPECT().GetLogger().Return(logger) req := fakerest.RESTClient{ Client: fakerest.CreateHTTPClient(func(request *http.Request) (*http.Response, error) { resp := &http.Response{ @@ -533,21 +520,36 @@ func TestKubeLoggingWithReconnect(t *testing.T) { }), NegotiatedSerializer: scheme.Codecs.WithoutConversion(), } - mockKubeAPI.EXPECT().GetLogs(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(req.Request()).AnyTimes() + mockKubeAPI.EXPECT().GetLogs(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(req.Request()) + }, + }, + { + name: "Kube error 503", + expectedCalls: func() { + mockBaseWorkUnit.EXPECT().GetWorkceptor().Return(w).MinTimes(1) + mockBaseWorkUnit.EXPECT().GetContext().Return(ctx).MinTimes(3) + mockKubeAPI.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(pod, nil) logger := logger.NewReceptorLogger("") - mockNetceptor.EXPECT().GetLogger().Return(logger).AnyTimes() - mockKubeAPI.EXPECT().SubResource(gomock.Any(), gomock.Any(), gomock.Any()).Return(req.Request()).AnyTimes() - exec := ex{} - mockKubeAPI.EXPECT().NewSPDYExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(&exec, nil).AnyTimes() + mockNetceptor.EXPECT().GetLogger().Return(logger).MinTimes(1) + mockBaseWorkUnit.EXPECT().UpdateBasicStatus(gomock.Any(), gomock.Any(), gomock.Any()) + req := fakerest.RESTClient{ + Client: fakerest.CreateHTTPClient(func(request *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, // 503 + Body: 
io.NopCloser(strings.NewReader("kube error")), + } + + return resp, nil + }), + NegotiatedSerializer: scheme.Codecs.WithoutConversion(), + } + mockKubeAPI.EXPECT().GetLogs(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(req.Request()) }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.expectedCalls() - ku.Start() - time.Sleep(10 * time.Millisecond) - kw.CreatePod(nil) wg := &sync.WaitGroup{} wg.Add(1) mockfilesystemer := mock_workceptor.NewMockFileSystemer(ctrl) @@ -555,7 +557,7 @@ func TestKubeLoggingWithReconnect(t *testing.T) { stdout, _ := workceptor.NewStdoutWriter(mockfilesystemer, "") mockFileWC := mock_workceptor.NewMockFileWriteCloser(ctrl) stdout.SetWriter(mockFileWC) - mockFileWC.EXPECT().Write(gomock.AnyOf([]byte("HI\n"), []byte(" kube error\n"))).Return(0, nil).Times(2) + mockFileWC.EXPECT().Write(gomock.AnyOf([]byte("HI\n"), []byte(" kube error\n"))).Return(0, nil).AnyTimes() kw.KubeLoggingWithReconnect(wg, stdout, &stdinErr, &stdoutErr) }) } diff --git a/pkg/workceptor/python.go b/pkg/workceptor/python.go index c9e30745f..d471f32e2 100644 --- a/pkg/workceptor/python.go +++ b/pkg/workceptor/python.go @@ -21,6 +21,18 @@ type pythonUnit struct { config map[string]interface{} } +// NewPythonUnit creates a new pythonUnit using the given parameters. +func NewPythonUnit(baseWorkUnit BaseWorkUnitForWorkUnit, plugin string, function string, config map[string]any) *pythonUnit { + return &pythonUnit{ + commandUnit: commandUnit{ + BaseWorkUnitForWorkUnit: baseWorkUnit, + }, + plugin: plugin, + function: function, + config: config, + } +} + // Start launches a job with given parameters. func (pw *pythonUnit) Start() error { pw.UpdateBasicStatus(WorkStatePending, "[DEPRECATION WARNING] '--work-python' option is not currently being used. 
This feature will be removed from receptor in a future release.", 0) diff --git a/pkg/workceptor/python_test.go b/pkg/workceptor/python_test.go new file mode 100644 index 000000000..d9d153e4c --- /dev/null +++ b/pkg/workceptor/python_test.go @@ -0,0 +1,156 @@ +package workceptor_test + +import ( + "context" + "fmt" + "math" + "os" + "path" + "slices" + "strings" + "sync" + "testing" + + "github.com/ansible/receptor/pkg/workceptor" + "github.com/ansible/receptor/pkg/workceptor/mock_workceptor" + "go.uber.org/mock/gomock" +) + +func createPythonUnitTestSetup(t *testing.T) (workceptor.WorkUnit, *mock_workceptor.MockBaseWorkUnitForWorkUnit, *mock_workceptor.MockNetceptorForWorkceptor, *workceptor.Workceptor) { + ctrl := gomock.NewController(t) + ctx := context.Background() + + mockBaseWorkUnit := mock_workceptor.NewMockBaseWorkUnitForWorkUnit(ctrl) + mockNetceptor := mock_workceptor.NewMockNetceptorForWorkceptor(ctrl) + mockNetceptor.EXPECT().NodeID().Return("NodeID") + mockNetceptor.EXPECT().GetLogger() + + w, err := workceptor.New(ctx, mockNetceptor, "/tmp") + if err != nil { + t.Errorf("Error while creating Workceptor: %v", err) + } + + mockBaseWorkUnit.EXPECT().Init(w, "", "", workceptor.FileSystem{}, nil) + mockBaseWorkUnit.EXPECT().SetStatusExtraData(gomock.Any()) + workUnit := workceptor.NewRemoteWorker(mockBaseWorkUnit, w, "", "") + + return workUnit, mockBaseWorkUnit, mockNetceptor, w +} + +// Creates a no-op script that can be run during tests. 
+func createReceptorPythonWorkerScript() error { + tmpDir := "/tmp" + filename := "receptor-python-worker" + absoluteFilename := path.Join(tmpDir, filename) + if workDir, err := os.Getwd(); err != nil { + fmt.Printf("os.Getwd=%s", workDir) + } + + tmpInPath := slices.Contains(strings.Split(os.Getenv("PATH"), ":"), tmpDir) + if !tmpInPath { + newPath := os.Getenv("PATH") + ":" + tmpDir + err := os.Setenv("PATH", newPath) + if err != nil { + return fmt.Errorf("Error setting PATH: %v", err) + } + } + + f, err := os.Create(absoluteFilename) + if err != nil { + return fmt.Errorf("Error creating %s: %v", filename, err) + } + defer f.Close() + + _, err = f.WriteString("#!/usr/bin/env /bin/sh\necho \"\" > /dev/null") + if err != nil { + return fmt.Errorf("Error writing to %s: %v", filename, err) + } + + err = os.Chmod(absoluteFilename, 0o755) + if err != nil { + return fmt.Errorf("Error making %s executable: %v", absoluteFilename, err) + } + + return nil +} + +func TestPythonUnitStartRunsToSuccess(t *testing.T) { + _, mockBaseWorkUnit, _, _ := createPythonUnitTestSetup(t) //nolint:dogsled + pw := workceptor.NewPythonUnit(mockBaseWorkUnit, "", "", nil) + statusLock := &sync.RWMutex{} + + mockBaseWorkUnit.EXPECT().UpdateBasicStatus(gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + mockBaseWorkUnit.EXPECT().GetStatusLock().Return(statusLock).Times(2) + mockBaseWorkUnit.EXPECT().GetStatusWithoutExtraData().Return(&workceptor.StatusFileData{}) + mockBaseWorkUnit.EXPECT().GetStatusCopy().Return(workceptor.StatusFileData{ + ExtraData: &workceptor.CommandExtraData{}, + }) + mockBaseWorkUnit.EXPECT().UnitDir() + mockBaseWorkUnit.EXPECT().UpdateFullStatus(gomock.Any()) + mockBaseWorkUnit.EXPECT().MonitorLocalStatus().AnyTimes() + mockBaseWorkUnit.EXPECT().UpdateFullStatus(gomock.Any()).AnyTimes() + + createReceptorPythonWorkerScript() + + err := pw.Start() + if err != nil { + t.Errorf("Error when testing Start method of pythonUnit: %v", err) + } +} + +func 
TestPythonUnitStartFailsOnInvalidConfig(t *testing.T) { + _, mockBaseWorkUnit, _, _ := createPythonUnitTestSetup(t) //nolint:dogsled + + statusLock := &sync.RWMutex{} + + mockBaseWorkUnit.EXPECT().UpdateBasicStatus(gomock.Any(), gomock.Any(), gomock.Any()).Times(2) + mockBaseWorkUnit.EXPECT().GetStatusLock().Return(statusLock).Times(2) + mockBaseWorkUnit.EXPECT().GetStatusWithoutExtraData().Return(&workceptor.StatusFileData{}) + mockBaseWorkUnit.EXPECT().GetStatusCopy().Return(workceptor.StatusFileData{ + ExtraData: &workceptor.CommandExtraData{}, + }) + + badConfig := map[string]any{ + "badOption": math.Inf(1), + } + pw := workceptor.NewPythonUnit(mockBaseWorkUnit, "", "", badConfig) + + createReceptorPythonWorkerScript() + + err := pw.Start() + if err == nil { + t.Errorf("Expected json marshal error for configuration.") + } +} + +func TestWorkPythonConfigNewWorkerRunsToSuccess(t *testing.T) { + _, mockBaseWorkUnitForWorkUnit, mockNetceptorForWorkceptor, _ := createPythonUnitTestSetup(t) + mockNetceptorForWorkceptor.EXPECT().NodeID().AnyTimes() + + wpc := &workceptor.WorkPythonCfg{} + w, err := workceptor.New(context.TODO(), mockNetceptorForWorkceptor, "") + if err != nil { + t.Errorf("Error when creating workceptor for test: %v", err) + } + workUnit := wpc.NewWorker(mockBaseWorkUnitForWorkUnit, w, "", "") + if workUnit == nil { + t.Error("Returned WorkUnit was nil") + } +} + +func TestWorkPythonConfigRunRunsToSuccess(t *testing.T) { + _, _, mockNetceptorForWorkceptor, _ := createPythonUnitTestSetup(t) //nolint:dogsled + mockNetceptorForWorkceptor.EXPECT().NodeID().AnyTimes() + mockNetceptorForWorkceptor.EXPECT().AddWorkCommand("", false) + + wpc := &workceptor.WorkPythonCfg{} + w, err := workceptor.New(context.TODO(), mockNetceptorForWorkceptor, "") + if err != nil { + t.Errorf("Error when creating workceptor for test: %v", err) + } + workceptor.MainInstance = w + err = wpc.Run() + if err == nil { + t.Errorf("Expected deprecation warning but received none.") + } 
+}