diff --git a/api/cometbft/mempool/v1/message.go b/api/cometbft/mempool/v1/message.go index a2093ee8682..5bbaa613bae 100644 --- a/api/cometbft/mempool/v1/message.go +++ b/api/cometbft/mempool/v1/message.go @@ -3,9 +3,17 @@ package v1 import ( "fmt" + "github.com/cometbft/cometbft/types" "github.com/cosmos/gogoproto/proto" ) +var ( + _ types.Wrapper = &Txs{} + _ types.Wrapper = &SeenTx{} + _ types.Wrapper = &WantTx{} + _ types.Unwrapper = &Message{} +) + // Wrap implements the p2p Wrapper interface and wraps a mempool message. func (m *Txs) Wrap() proto.Message { mm := &Message{} @@ -13,6 +21,20 @@ func (m *Txs) Wrap() proto.Message { return mm } +// Wrap implements the p2p Wrapper interface and wraps a mempool seen tx message. +func (m *SeenTx) Wrap() proto.Message { + mm := &Message{} + mm.Sum = &Message_SeenTx{SeenTx: m} + return mm +} + +// Wrap implements the p2p Wrapper interface and wraps a mempool want tx message. +func (m *WantTx) Wrap() proto.Message { + mm := &Message{} + mm.Sum = &Message_WantTx{WantTx: m} + return mm +} + // Unwrap implements the p2p Wrapper interface and unwraps a wrapped mempool // message. func (m *Message) Unwrap() (proto.Message, error) { @@ -20,6 +42,12 @@ func (m *Message) Unwrap() (proto.Message, error) { case *Message_Txs: return m.GetTxs(), nil + case *Message_SeenTx: + return m.GetSeenTx(), nil + + case *Message_WantTx: + return m.GetWantTx(), nil + default: return nil, fmt.Errorf("unknown message: %T", msg) } diff --git a/api/cometbft/mempool/v1/types.pb.go b/api/cometbft/mempool/v1/types.pb.go index dac12ec1685..fdbe1b5b141 100644 --- a/api/cometbft/mempool/v1/types.pb.go +++ b/api/cometbft/mempool/v1/types.pb.go @@ -67,6 +67,96 @@ func (m *Txs) GetTxs() [][]byte { return nil } +// SeenTx contains a list of transaction keys seen by the sender. +type SeenTx struct { + TxKey []byte `protobuf:"bytes,1,opt,name=tx_key,json=txKey,proto3" json:"tx_key,omitempty"` +} + +func (m *SeenTx) Reset() { *m = SeenTx{} } +func (m *SeenTx) String() string { return proto.CompactTextString(m) } +func (*SeenTx) ProtoMessage() {} +func (*SeenTx) Descriptor() ([]byte, []int) { + return fileDescriptor_d8bb39f484575b79, []int{1} +} +func (m *SeenTx) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *SeenTx) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_SeenTx.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *SeenTx) XXX_Merge(src proto.Message) { + xxx_messageInfo_SeenTx.Merge(m, src) +} +func (m *SeenTx) XXX_Size() int { + return m.Size() +} +func (m *SeenTx) XXX_DiscardUnknown() { + xxx_messageInfo_SeenTx.DiscardUnknown(m) +} + +var xxx_messageInfo_SeenTx proto.InternalMessageInfo + +func (m *SeenTx) GetTxKey() []byte { + if m != nil { + return m.TxKey + } + return nil +} + +// WantTx contains a list of transaction keys wanted by the sender. 
+type WantTx struct { + TxKey []byte `protobuf:"bytes,1,opt,name=tx_key,json=txKey,proto3" json:"tx_key,omitempty"` +} + +func (m *WantTx) Reset() { *m = WantTx{} } +func (m *WantTx) String() string { return proto.CompactTextString(m) } +func (*WantTx) ProtoMessage() {} +func (*WantTx) Descriptor() ([]byte, []int) { + return fileDescriptor_d8bb39f484575b79, []int{2} +} +func (m *WantTx) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *WantTx) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_WantTx.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *WantTx) XXX_Merge(src proto.Message) { + xxx_messageInfo_WantTx.Merge(m, src) +} +func (m *WantTx) XXX_Size() int { + return m.Size() +} +func (m *WantTx) XXX_DiscardUnknown() { + xxx_messageInfo_WantTx.DiscardUnknown(m) +} + +var xxx_messageInfo_WantTx proto.InternalMessageInfo + +func (m *WantTx) GetTxKey() []byte { + if m != nil { + return m.TxKey + } + return nil +} + // Message is an abstract mempool message. type Message struct { // Sum of all possible messages. @@ -74,6 +164,8 @@ type Message struct { // Types that are valid to be assigned to Sum: // // *Message_Txs + // *Message_SeenTx + // *Message_WantTx Sum isMessage_Sum `protobuf_oneof:"sum"` } @@ -81,7 +173,7 @@ func (m *Message) Reset() { *m = Message{} } func (m *Message) String() string { return proto.CompactTextString(m) } func (*Message) ProtoMessage() {} func (*Message) Descriptor() ([]byte, []int) { - return fileDescriptor_d8bb39f484575b79, []int{1} + return fileDescriptor_d8bb39f484575b79, []int{3} } func (m *Message) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -119,8 +211,16 @@ type isMessage_Sum interface { type Message_Txs struct { Txs *Txs `protobuf:"bytes,1,opt,name=txs,proto3,oneof" json:"txs,omitempty"` } +type Message_SeenTx struct { + SeenTx *SeenTx `protobuf:"bytes,2,opt,name=seen_tx,json=seenTx,proto3,oneof" json:"seen_tx,omitempty"` +} +type Message_WantTx struct { + WantTx *WantTx `protobuf:"bytes,3,opt,name=want_tx,json=wantTx,proto3,oneof" json:"want_tx,omitempty"` +} -func (*Message_Txs) isMessage_Sum() {} +func (*Message_Txs) isMessage_Sum() {} +func (*Message_SeenTx) isMessage_Sum() {} +func (*Message_WantTx) isMessage_Sum() {} func (m *Message) GetSum() isMessage_Sum { if m != nil { @@ -136,34 +236,57 @@ func (m *Message) GetTxs() *Txs { return nil } +func (m *Message) GetSeenTx() *SeenTx { + if x, ok := m.GetSum().(*Message_SeenTx); ok { + return x.SeenTx + } + return nil +} + +func (m *Message) GetWantTx() *WantTx { + if x, ok := m.GetSum().(*Message_WantTx); ok { + return x.WantTx + } + return nil +} + // XXX_OneofWrappers is for the internal use of the proto package. 
func (*Message) XXX_OneofWrappers() []interface{} { return []interface{}{ (*Message_Txs)(nil), + (*Message_SeenTx)(nil), + (*Message_WantTx)(nil), } } func init() { proto.RegisterType((*Txs)(nil), "cometbft.mempool.v1.Txs") + proto.RegisterType((*SeenTx)(nil), "cometbft.mempool.v1.SeenTx") + proto.RegisterType((*WantTx)(nil), "cometbft.mempool.v1.WantTx") proto.RegisterType((*Message)(nil), "cometbft.mempool.v1.Message") } func init() { proto.RegisterFile("cometbft/mempool/v1/types.proto", fileDescriptor_d8bb39f484575b79) } var fileDescriptor_d8bb39f484575b79 = []byte{ - // 182 bytes of a gzipped FileDescriptorProto + // 272 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x4f, 0xce, 0xcf, 0x4d, 0x2d, 0x49, 0x4a, 0x2b, 0xd1, 0xcf, 0x4d, 0xcd, 0x2d, 0xc8, 0xcf, 0xcf, 0xd1, 0x2f, 0x33, 0xd4, 0x2f, 0xa9, 0x2c, 0x48, 0x2d, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x12, 0x86, 0x29, 0xd0, 0x83, 0x2a, 0xd0, 0x2b, 0x33, 0x54, 0x12, 0xe7, 0x62, 0x0e, 0xa9, 0x28, 0x16, 0x12, 0xe0, 0x62, - 0x2e, 0xa9, 0x28, 0x96, 0x60, 0x54, 0x60, 0xd6, 0xe0, 0x09, 0x02, 0x31, 0x95, 0xec, 0xb8, 0xd8, - 0x7d, 0x53, 0x8b, 0x8b, 0x13, 0xd3, 0x53, 0x85, 0x74, 0x60, 0x92, 0x8c, 0x1a, 0xdc, 0x46, 0x12, - 0x7a, 0x58, 0x8c, 0xd1, 0x0b, 0xa9, 0x28, 0xf6, 0x60, 0x00, 0x6b, 0x74, 0x62, 0xe5, 0x62, 0x2e, - 0x2e, 0xcd, 0x75, 0xf2, 0x3b, 0xf1, 0x48, 0x8e, 0xf1, 0xc2, 0x23, 0x39, 0xc6, 0x07, 0x8f, 0xe4, - 0x18, 0x27, 0x3c, 0x96, 0x63, 0xb8, 0xf0, 0x58, 0x8e, 0xe1, 0xc6, 0x63, 0x39, 0x86, 0x28, 0x93, - 0xf4, 0xcc, 0x92, 0x8c, 0xd2, 0x24, 0x90, 0x39, 0xfa, 0x70, 0x37, 0xc3, 0x19, 0x89, 0x05, 0x99, - 0xfa, 0x58, 0x7c, 0x92, 0xc4, 0x06, 0xf6, 0x84, 0x31, 0x20, 0x00, 0x00, 0xff, 0xff, 0x3f, 0x09, - 0x89, 0xce, 0xe7, 0x00, 0x00, 0x00, + 0x2e, 0xa9, 0x28, 0x96, 0x60, 0x54, 0x60, 0xd6, 0xe0, 0x09, 0x02, 0x31, 0x95, 0xe4, 0xb9, 0xd8, + 0x82, 0x53, 0x53, 0xf3, 0x42, 0x2a, 0x84, 0x44, 0xb9, 0xd8, 0x4a, 0x2a, 0xe2, 0xb3, 0x53, 0x2b, + 0x25, 0x18, 0x15, 0x18, 0x35, 0x78, 0x82, 0x58, 0x4b, 0x2a, 0xbc, 0x53, 0x2b, 0x41, 0x0a, 0xc2, + 0x13, 0xf3, 0x4a, 0x70, 0x2b, 0x58, 0xc7, 0xc8, 0xc5, 0xee, 0x9b, 0x5a, 0x5c, 0x9c, 0x98, 0x9e, + 0x2a, 0xa4, 0x03, 0x33, 0x9f, 0x51, 0x83, 0xdb, 0x48, 0x42, 0x0f, 0x8b, 0x4b, 0xf4, 0x42, 0x2a, + 0x8a, 0x3d, 0x18, 0xc0, 0x76, 0x0b, 0x99, 0x71, 0xb1, 0x17, 0xa7, 0xa6, 0xe6, 0xc5, 0x97, 0x54, + 0x48, 0x30, 0x81, 0x75, 0x48, 0x63, 0xd5, 0x01, 0x71, 0x9f, 0x07, 0x43, 0x10, 0x5b, 0x31, 0xc4, + 0xa5, 0x66, 0x5c, 0xec, 0xe5, 0x89, 0x79, 0x25, 0x20, 0x7d, 0xcc, 0x78, 0xf4, 0x41, 0x9c, 0x0d, + 0xd2, 0x57, 0x0e, 0x66, 0x39, 0xb1, 0x72, 0x31, 0x17, 0x97, 0xe6, 0x3a, 0xf9, 0x9d, 0x78, 0x24, + 0xc7, 0x78, 0xe1, 0x91, 0x1c, 0xe3, 0x83, 0x47, 0x72, 0x8c, 0x13, 0x1e, 0xcb, 0x31, 0x5c, 0x78, + 0x2c, 0xc7, 0x70, 0xe3, 0xb1, 0x1c, 0x43, 0x94, 0x49, 0x7a, 0x66, 0x49, 0x46, 0x69, 0x12, 0xc8, + 0x34, 0x7d, 0x78, 0x30, 0xc3, 0x19, 0x89, 0x05, 0x99, 0xfa, 0x58, 0x02, 0x3f, 0x89, 0x0d, 0x1c, + 0xee, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x42, 0x48, 0x11, 0xab, 0x9a, 0x01, 0x00, 0x00, } func (m *Txs) Marshal() (dAtA []byte, err error) { @@ -198,6 +321,66 @@ func (m *Txs) MarshalToSizedBuffer(dAtA []byte) (int, error) { return len(dAtA) - i, nil } +func (m *SeenTx) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *SeenTx) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return 
m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *SeenTx) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.TxKey) > 0 { + i -= len(m.TxKey) + copy(dAtA[i:], m.TxKey) + i = encodeVarintTypes(dAtA, i, uint64(len(m.TxKey))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *WantTx) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *WantTx) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *WantTx) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.TxKey) > 0 { + i -= len(m.TxKey) + copy(dAtA[i:], m.TxKey) + i = encodeVarintTypes(dAtA, i, uint64(len(m.TxKey))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + func (m *Message) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) @@ -251,6 +434,48 @@ func (m *Message_Txs) MarshalToSizedBuffer(dAtA []byte) (int, error) { } return len(dAtA) - i, nil } +func (m *Message_SeenTx) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Message_SeenTx) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + if m.SeenTx != nil { + { + size, err := m.SeenTx.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintTypes(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x12 + } + return len(dAtA) - i, nil +} +func (m *Message_WantTx) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Message_WantTx) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + if m.WantTx != nil { + { + size, err := m.WantTx.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintTypes(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x1a + } + return len(dAtA) - i, nil +} func encodeVarintTypes(dAtA []byte, offset int, v uint64) int { offset -= sovTypes(v) base := offset @@ -277,6 +502,32 @@ func (m *Txs) Size() (n int) { return n } +func (m *SeenTx) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.TxKey) + if l > 0 { + n += 1 + l + sovTypes(uint64(l)) + } + return n +} + +func (m *WantTx) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.TxKey) + if l > 0 { + n += 1 + l + sovTypes(uint64(l)) + } + return n +} + func (m *Message) Size() (n int) { if m == nil { return 0 @@ -301,6 +552,30 @@ func (m *Message_Txs) Size() (n int) { } return n } +func (m *Message_SeenTx) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.SeenTx != nil { + l = m.SeenTx.Size() + n += 1 + l + sovTypes(uint64(l)) + } + return n +} +func (m *Message_WantTx) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.WantTx != nil { + l = m.WantTx.Size() + n += 1 + l + sovTypes(uint64(l)) + } + return n +} func sovTypes(x uint64) (n int) { return (math_bits.Len64(x|1) + 6) / 7 @@ -390,6 +665,174 @@ func (m *Txs) Unmarshal(dAtA []byte) error { } return nil } +func (m *SeenTx) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTypes + } + if iNdEx >= l { + 
return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: SeenTx: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: SeenTx: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field TxKey", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTypes + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthTypes + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthTypes + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.TxKey = append(m.TxKey[:0], dAtA[iNdEx:postIndex]...) + if m.TxKey == nil { + m.TxKey = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipTypes(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthTypes + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *WantTx) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTypes + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: WantTx: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: WantTx: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field TxKey", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTypes + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthTypes + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthTypes + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.TxKey = append(m.TxKey[:0], dAtA[iNdEx:postIndex]...) 
+ if m.TxKey == nil { + m.TxKey = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipTypes(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthTypes + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} func (m *Message) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 @@ -454,6 +897,76 @@ func (m *Message) Unmarshal(dAtA []byte) error { } m.Sum = &Message_Txs{v} iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SeenTx", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTypes + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthTypes + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthTypes + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + v := &SeenTx{} + if err := v.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + m.Sum = &Message_SeenTx{v} + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field WantTx", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTypes + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthTypes + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthTypes + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + v := &WantTx{} + if err := v.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + m.Sum = &Message_WantTx{v} + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipTypes(dAtA[iNdEx:]) diff --git a/config/config.go b/config/config.go index d74e8084aa5..d793a0ebe18 100644 --- a/config/config.go +++ b/config/config.go @@ -49,6 +49,7 @@ const ( v2 = "v2" MempoolTypeFlood = "flood" + MempoolTypeCat = "cat" MempoolTypeNop = "nop" ) @@ -845,6 +846,7 @@ type MempoolConfig struct { // Possible types: // - "flood" : concurrent linked list mempool with flooding gossip protocol // (default) + // - "cat" : 'push-pull'-type protocol for Content-Addressable Transactions // - "nop" : nop-mempool (short for no operation; the ABCI app is // responsible for storing, disseminating and proposing txs). // "create_empty_blocks=false" is not supported. diff --git a/config/toml.go b/config/toml.go index eeb567deee5..dcd7766a4e1 100644 --- a/config/toml.go +++ b/config/toml.go @@ -385,6 +385,7 @@ dial_timeout = "{{ .P2P.DialTimeout }}" # Possible types: # - "flood" : concurrent linked list mempool with flooding gossip protocol # (default) +# - "cat" : 'push-pull'-type protocol for Content-Addressable Transactions # - "nop" : nop-mempool (short for no operation; the ABCI app is responsible # for storing, disseminating and proposing txs). "create_empty_blocks=false" is # not supported. 
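NOTE: the hunks above add the "cat" mempool type to the configuration and the SeenTx/WantTx wire messages with their Wrap/Unwrap plumbing. Below is a minimal standalone sketch, not part of this patch, of selecting the CAT mempool and round-tripping a SeenTx through the Message envelope; it assumes the existing MempoolConfig.Type field and config.DefaultMempoolConfig helper.

package main

import (
	"fmt"

	protomem "github.com/cometbft/cometbft/api/cometbft/mempool/v1"
	cfg "github.com/cometbft/cometbft/config"
)

func main() {
	// Select the CAT mempool, equivalent to `type = "cat"` in the
	// [mempool] section of config.toml.
	mempoolCfg := cfg.DefaultMempoolConfig()
	mempoolCfg.Type = cfg.MempoolTypeCat

	// Wrap a SeenTx into the oneof Message envelope, as the reactor does
	// before sending it on the mempool state channel; Unwrap recovers the
	// concrete message on the receiving side.
	seen := &protomem.SeenTx{TxKey: make([]byte, 32)} // 32-byte tx hash key
	wrapped := seen.Wrap().(*protomem.Message)
	inner, err := wrapped.Unwrap()
	if err != nil {
		panic(err)
	}
	fmt.Printf("mempool type=%q unwrapped=%T\n", mempoolCfg.Type, inner)
}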
diff --git a/internal/consensus/replay_stubs.go b/internal/consensus/replay_stubs.go index 0caaddc109f..e8cfd89f398 100644 --- a/internal/consensus/replay_stubs.go +++ b/internal/consensus/replay_stubs.go @@ -25,6 +25,12 @@ func (emptyMempool) CheckTx(types.Tx) (*abcicli.ReqRes, error) { return nil, nil } +func (emptyMempool) CheckNewTx(types.Tx) (*abcicli.ReqRes, error) { + return nil, nil +} + +func (emptyMempool) InvokeNewTxReceivedOnReactor(types.TxKey) {} + func (txmp emptyMempool) RemoveTxByKey(types.TxKey) error { return nil } diff --git a/mempool/cache.go b/mempool/cache.go index 6a854554d1c..18772e40849 100644 --- a/mempool/cache.go +++ b/mempool/cache.go @@ -7,67 +7,64 @@ import ( "github.com/cometbft/cometbft/types" ) -// TxCache defines an interface for raw transaction caching in a mempool. +// TxCache defines an interface for transaction caching. // Currently, a TxCache does not allow direct reading or getting of transaction // values. A TxCache is used primarily to push transactions and removing // transactions. Pushing via Push returns a boolean telling the caller if the // transaction already exists in the cache or not. -type TxCache interface { +type TxCache[T comparable] interface { // Reset resets the cache to an empty state. Reset() - // Push adds the given raw transaction to the cache and returns true if it was + // Push adds the given transaction key to the cache and returns true if it was // newly added. Otherwise, it returns false. - Push(tx types.Tx) bool + Push(v T) bool - // Remove removes the given raw transaction from the cache. - Remove(tx types.Tx) + // Remove removes the given transaction from the cache. + Remove(v T) // Has reports whether tx is present in the cache. Checking for presence is // not treated as an access of the value. - Has(tx types.Tx) bool + Has(v T) bool } -var _ TxCache = (*LRUTxCache)(nil) +var _ TxCache[types.TxKey] = (*LRUTxCache[types.TxKey])(nil) -// LRUTxCache maintains a thread-safe LRU cache of raw transactions. The cache -// only stores the hash of the raw transaction. -type LRUTxCache struct { +// LRUTxCache maintains a thread-safe LRU cache of transaction hashes (keys). +type LRUTxCache[T comparable] struct { mtx cmtsync.Mutex size int - cacheMap map[types.TxKey]*list.Element + cacheMap map[T]*list.Element list *list.List } -func NewLRUTxCache(cacheSize int) *LRUTxCache { - return &LRUTxCache{ +func NewLRUTxCache[T comparable](cacheSize int) *LRUTxCache[T] { + return &LRUTxCache[T]{ size: cacheSize, - cacheMap: make(map[types.TxKey]*list.Element, cacheSize), + cacheMap: make(map[T]*list.Element, cacheSize), list: list.New(), } } // GetList returns the underlying linked-list that backs the LRU cache. Note, // this should be used for testing purposes only! 
-func (c *LRUTxCache) GetList() *list.List { +func (c *LRUTxCache[T]) GetList() *list.List { return c.list } -func (c *LRUTxCache) Reset() { +func (c *LRUTxCache[T]) Reset() { c.mtx.Lock() defer c.mtx.Unlock() - c.cacheMap = make(map[types.TxKey]*list.Element, c.size) + c.cacheMap = make(map[T]*list.Element, c.size) c.list.Init() } -func (c *LRUTxCache) Push(tx types.Tx) bool { +func (c *LRUTxCache[T]) Push(v T) bool { c.mtx.Lock() defer c.mtx.Unlock() - key := tx.Key() - - moved, ok := c.cacheMap[key] + moved, ok := c.cacheMap[v] if ok { c.list.MoveToBack(moved) return false @@ -76,45 +73,44 @@ func (c *LRUTxCache) Push(tx types.Tx) bool { if c.list.Len() >= c.size { front := c.list.Front() if front != nil { - frontKey := front.Value.(types.TxKey) + frontKey := front.Value.(T) delete(c.cacheMap, frontKey) c.list.Remove(front) } } - e := c.list.PushBack(key) - c.cacheMap[key] = e + e := c.list.PushBack(v) + c.cacheMap[v] = e return true } -func (c *LRUTxCache) Remove(tx types.Tx) { +func (c *LRUTxCache[T]) Remove(v T) { c.mtx.Lock() defer c.mtx.Unlock() - key := tx.Key() - e := c.cacheMap[key] - delete(c.cacheMap, key) + e := c.cacheMap[v] + delete(c.cacheMap, v) if e != nil { c.list.Remove(e) } } -func (c *LRUTxCache) Has(tx types.Tx) bool { +func (c *LRUTxCache[T]) Has(v T) bool { c.mtx.Lock() defer c.mtx.Unlock() - _, ok := c.cacheMap[tx.Key()] + _, ok := c.cacheMap[v] return ok } -// NopTxCache defines a no-op raw transaction cache. -type NopTxCache struct{} +// NopTxCache defines a no-op transaction cache. +type NopTxCache[T comparable] struct{} -var _ TxCache = (*NopTxCache)(nil) +var _ TxCache[types.TxKey] = (*NopTxCache[types.TxKey])(nil) -func (NopTxCache) Reset() {} -func (NopTxCache) Push(types.Tx) bool { return true } -func (NopTxCache) Remove(types.Tx) {} -func (NopTxCache) Has(types.Tx) bool { return false } +func (NopTxCache[T]) Reset() {} +func (NopTxCache[T]) Push(types.TxKey) bool { return true } +func (NopTxCache[T]) Remove(types.TxKey) {} +func (NopTxCache[T]) Has(types.TxKey) bool { return false } diff --git a/mempool/cache_bench_test.go b/mempool/cache_bench_test.go index 1c26999d106..282d8a2711d 100644 --- a/mempool/cache_bench_test.go +++ b/mempool/cache_bench_test.go @@ -3,10 +3,12 @@ package mempool import ( "encoding/binary" "testing" + + "github.com/cometbft/cometbft/types" ) func BenchmarkCacheInsertTime(b *testing.B) { - cache := NewLRUTxCache(b.N) + cache := NewLRUTxCache[types.TxKey](b.N) txs := make([][]byte, b.N) for i := 0; i < b.N; i++ { @@ -17,25 +19,25 @@ func BenchmarkCacheInsertTime(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - cache.Push(txs[i]) + cache.Push(types.Tx(txs[i]).Key()) } } // This benchmark is probably skewed, since we actually will be removing // txs in parallel, which may cause some overhead due to mutex locking. 
func BenchmarkCacheRemoveTime(b *testing.B) { - cache := NewLRUTxCache(b.N) + cache := NewLRUTxCache[types.TxKey](b.N) txs := make([][]byte, b.N) for i := 0; i < b.N; i++ { txs[i] = make([]byte, 8) binary.BigEndian.PutUint64(txs[i], uint64(i)) - cache.Push(txs[i]) + cache.Push(types.Tx(txs[i]).Key()) } b.ResetTimer() for i := 0; i < b.N; i++ { - cache.Remove(txs[i]) + cache.Remove(types.Tx(txs[i]).Key()) } } diff --git a/mempool/cache_test.go b/mempool/cache_test.go index 35089ba4790..8cea584fa1e 100644 --- a/mempool/cache_test.go +++ b/mempool/cache_test.go @@ -15,18 +15,18 @@ import ( ) func TestCacheRemove(t *testing.T) { - cache := NewLRUTxCache(100) + cache := NewLRUTxCache[types.TxKey](100) numTxs := 10 - txs := make([][]byte, numTxs) + txs := make([]types.TxKey, numTxs) for i := 0; i < numTxs; i++ { // probability of collision is 2**-256 txBytes := make([]byte, 32) _, err := rand.Read(txBytes) require.NoError(t, err) - txs[i] = txBytes - cache.Push(txBytes) + txs[i] = types.Tx(txBytes).Key() + cache.Push(txs[i]) // make sure its added to both the linked list and the map require.Len(t, cache.cacheMap, i+1) @@ -85,7 +85,7 @@ func TestCacheAfterUpdate(t *testing.T) { } } - cache := mp.cache.(*LRUTxCache) + cache := mp.cache.(*LRUTxCache[types.TxKey]) node := cache.GetList().Front() counter := 0 for node != nil { diff --git a/mempool/cat/reactor.go b/mempool/cat/reactor.go new file mode 100644 index 00000000000..48194fadc75 --- /dev/null +++ b/mempool/cat/reactor.go @@ -0,0 +1,457 @@ +package cat + +import ( + "errors" + "fmt" + "sync" + "time" + + abci "github.com/cometbft/cometbft/abci/types" + protomem "github.com/cometbft/cometbft/api/cometbft/mempool/v1" + cfg "github.com/cometbft/cometbft/config" + "github.com/cometbft/cometbft/crypto/tmhash" + "github.com/cometbft/cometbft/libs/log" + "github.com/cometbft/cometbft/mempool" + "github.com/cometbft/cometbft/p2p" + "github.com/cometbft/cometbft/types" +) + +const ( + // default duration to wait before considering a peer non-responsive + // and searching for the tx from a new peer. + defaultGossipDelay = 1000 * time.Millisecond + + // Content Addressable Tx Pool gossips state based messages (SeenTx and WantTx) on a separate channel + // for cross compatibility. + MempoolStateChannel = byte(0x31) + + // peerHeightDiff signifies the tolerance in difference in height between the peer and the height + // the node received the tx. + peerHeightDiff = 10 +) + +// Mempool reactor that implements a push-pull gossip protocol. +type Reactor struct { + mempool.WaitSyncReactor + mempool *mempool.CListMempool + + peerIDs sync.Map // set of connected peers + requests *requestScheduler // to track requested transactions + + // Thread-safe list of transactions peers have seen that we have not yet seen + seenByPeersSet *SeenTxSet +} + +// NewReactor returns a new Reactor with the given config and mempool. 
+func NewReactor(config *cfg.MempoolConfig, mp *mempool.CListMempool, waitSync bool, logger log.Logger) *Reactor { + memR := &Reactor{ + WaitSyncReactor: *mempool.NewWaitSyncReactor(config, waitSync), + mempool: mp, + requests: newRequestScheduler(defaultGossipDelay, defaultGlobalRequestTimeout), + seenByPeersSet: NewSeenTxSet(), + } + memR.BaseReactor = *p2p.NewBaseReactor("Mempool", memR) + + memR.SetLogger(logger) + memR.mempool.SetTxRemovedCallback(func(txKey types.TxKey) { + memR.seenByPeersSet.RemoveKey(txKey) + }) + memR.mempool.SetNewTxReceivedCallback(func(txKey types.TxKey) { + // If we don't find the tx in the mempool, probably it is because it was + // invalid, so don't broadcast. + if entry := memR.mempool.GetEntry(txKey); !entry.IsEmpty() { + go memR.broadcastNewTx(entry) + } + }) + return memR +} + +// InitPeer implements Reactor by creating a state for the peer. +func (memR *Reactor) InitPeer(peer p2p.Peer) p2p.Peer { + memR.peerIDs.Store(peer.ID(), struct{}{}) + return peer +} + +// SetLogger sets the Logger on the reactor and the underlying mempool. +func (memR *Reactor) SetLogger(l log.Logger) { + memR.Logger = l +} + +// OnStart implements p2p.BaseReactor. +func (memR *Reactor) OnStart() error { + if memR.WaitSync() { + memR.Logger.Info("Starting reactor in sync mode: tx propagation will start once sync completes") + } + if !memR.Config.Broadcast { + memR.Logger.Info("Tx broadcasting is disabled") + } + return nil +} + +// OnStop implements Service. +func (memR *Reactor) OnStop() { + // stop all the timers tracking outbound requests + memR.requests.Close() +} + +// GetChannels implements Reactor by returning the list of channels for this +// reactor. +func (memR *Reactor) GetChannels() []*p2p.ChannelDescriptor { + largestTx := make([]byte, memR.Config.MaxTxBytes) + batchMsg := protomem.Message{ + Sum: &protomem.Message_Txs{ + Txs: &protomem.Txs{Txs: [][]byte{largestTx}}, + }, + } + + stateMsg := protomem.Message{ + Sum: &protomem.Message_SeenTx{ + SeenTx: &protomem.SeenTx{ + TxKey: make([]byte, tmhash.Size), + }, + }, + } + + return []*p2p.ChannelDescriptor{ + { + ID: mempool.MempoolChannel, + Priority: 6, + RecvMessageCapacity: batchMsg.Size(), + MessageType: &protomem.Message{}, + }, + { + ID: MempoolStateChannel, + Priority: 5, + RecvMessageCapacity: stateMsg.Size(), + MessageType: &protomem.Message{}, + }, + } +} + +// RemovePeer implements Reactor. For all current outbound requests to this +// peer it will find a new peer to rerequest the same transactions. +func (memR *Reactor) RemovePeer(peer p2p.Peer, _ interface{}) { + memR.peerIDs.Delete(peer.ID()) + // remove and rerequest all pending outbound requests to that peer since we know + // we won't receive any responses from them. + outboundRequests := memR.requests.ClearAllRequestsFrom(peer.ID()) + for key := range outboundRequests { + memR.mempool.Metrics.RequestedTxs.Add(1) + memR.findNewPeerToRequestTx(key) + } +} + +// Receive implements Reactor. +// It processes one of three messages: Txs, SeenTx, WantTx. +func (memR *Reactor) Receive(e p2p.Envelope) { + if memR.WaitSync() { + memR.Logger.Debug("Ignored message received while syncing", "msg", e.Message) + return + } + + switch msg := e.Message.(type) { + // A peer has sent us one or more transactions. This could be either because we requested them + // or because the peer received a new transaction and is broadcasting it to us. 
+ // NOTE: This setup also means that we can support older mempool implementations that simply + // flooded the network with transactions. + case *protomem.Txs: + protoTxs := msg.GetTxs() + memR.Logger.Debug("Received Txs", "src", e.Src, "chId", e.ChannelID, "msg", e.Message, "len", len(protoTxs)) + if len(protoTxs) == 0 { + memR.Logger.Error("received empty txs from peer", "src", e.Src) + return + } + + peerID := e.Src.ID() + for _, txBytes := range protoTxs { + tx := types.Tx(txBytes) + key := tx.Key() + + // If we requested the transaction, we mark it as received. + if memR.requests.Has(peerID, key) { + memR.requests.MarkReceived(peerID, key) + memR.Logger.Debug("received a response for a requested transaction", "peerID", peerID, "txKey", key) + } else { + // If we didn't request the transaction we simply mark the peer as having the + // tx (we'd have already done it if we were requesting the tx). + memR.markPeerHasTx(peerID, key) + memR.Logger.Debug("received new transaction", "peerID", peerID, "txKey", key) + } + + reqRes, err := memR.mempool.CheckTx(tx) + if errors.Is(err, mempool.ErrTxInCache) { + memR.Logger.Debug("Tx already exists in cache", "tx", tx.String()) + return + } + if err != nil { + memR.Logger.Info("Could not check tx", "tx", tx.String(), "err", err) + return + } + // Record the sender only when the transaction is valid and, as + // a consequence, added to the mempool. Senders are stored until + // the transaction is removed from the mempool. Note that it's + // possible a tx is still in the cache but no longer in the + // mempool. For example, after committing a block, txs are + // removed from mempool but not the cache. + reqRes.SetCallback(func(res *abci.Response) { + if res.GetCheckTx().Code == abci.CodeTypeOK { + memR.markPeerHasTx(e.Src.ID(), tx.Key()) + + // We broadcast only transactions that we deem valid and + // actually have in our mempool. + memR.broadcastSeenTx(key) + } + }) + } + + // A peer has indicated to us that it has a transaction. We first verify the txKey and + // mark that peer as having the transaction. Then we proceed with the following logic: + // + // 1. If we have the transaction, we do nothing. + // 2. If we don't yet have the tx but have an outgoing request for it, we do nothing. + // 3. If we recently evicted the tx and still don't have space for it, we do nothing. + // 4. Else, we request the transaction from that peer. + case *protomem.SeenTx: + txKey, err := types.TxKeyFromBytes(msg.TxKey) + if err != nil { + memR.Logger.Error("peer sent SeenTx with incorrect tx key", "err", err) + memR.Switch.StopPeerForError(e.Src, err) + return + } + memR.Logger.Debug("Received SeenTx", "src", e.Src, "chId", e.ChannelID, "txKey", txKey) + peerID := e.Src.ID() + memR.markPeerHasTx(peerID, txKey) + + // Check if we don't already have the transaction and that it was recently rejected + if memR.mempool.InMempool(txKey) || memR.mempool.InCache(txKey) || memR.mempool.WasRejected(txKey) { + memR.Logger.Debug("received a seen tx for a tx we already have or is in cache or was rejected", "txKey", txKey) + return + } + + // If we are already requesting that tx, then we don't need to go any further. 
+ if _, exists := memR.requests.ForTx(txKey); exists { + memR.Logger.Debug("received a SeenTx message for a transaction we are already requesting", "txKey", txKey) + return + } + + // We don't have the transaction, nor are we requesting it so we send the node + // a want msg + memR.requestTx(txKey, e.Src.ID()) + + // A peer is requesting a transaction that we have claimed to have. Find the specified + // transaction and broadcast it to the peer. We may no longer have the transaction + case *protomem.WantTx: + txKey, err := types.TxKeyFromBytes(msg.TxKey) + if err != nil { + memR.Logger.Error("peer sent WantTx with incorrect tx key", "err", err) + memR.Switch.StopPeerForError(e.Src, err) + return + } + memR.Logger.Debug("Received SeenTx", "src", e.Src, "chId", e.ChannelID, "txKey", txKey) + memR.sendRequestedTx(txKey, e.Src) + + default: + memR.Logger.Error("Received unknown message type", "src", e.Src, "chId", e.ChannelID, "msg", e.Message) + memR.Switch.StopPeerForError(e.Src, fmt.Errorf("mempool cannot handle message of type: %T", e.Message)) + return + } +} + +func (memR *Reactor) sendRequestedTx(txKey types.TxKey, peer p2p.Peer) { + if !memR.Config.Broadcast { + return + } + + if entry := memR.mempool.GetEntry(txKey); entry != nil { + memR.Logger.Debug("sending a tx in response to a want msg", "peer", peer.ID()) + txsMsg := p2p.Envelope{ + ChannelID: mempool.MempoolChannel, + Message: &protomem.Txs{Txs: [][]byte{entry.Tx()}}, + } + if peer.Send(txsMsg) { + memR.markPeerHasTx(peer.ID(), txKey) + } else { + memR.Logger.Error("send Txs message with requested transactions failed", "txKey", txKey, "peerID", peer.ID()) + } + } +} + +// PeerHasTx marks that the transaction has been seen by a peer. +func (memR *Reactor) markPeerHasTx(peerID p2p.ID, txKey types.TxKey) { + memR.Logger.Debug("mark that peer has tx", "peer", peerID, "txKey", txKey.String()) + memR.seenByPeersSet.Add(txKey, peerID) +} + +// PeerState describes the state of a peer. +type PeerState interface { + GetHeight() int64 +} + +// broadcastSeenTx broadcasts a SeenTx message to all peers unless we +// know they have already seen the transaction. +func (memR *Reactor) broadcastSeenTx(txKey types.TxKey) { + if !memR.Config.Broadcast { + return + } + + memR.Logger.Debug("Broadcasting SeenTx...", "tx", txKey.String()) + + msg := p2p.Envelope{ + ChannelID: MempoolStateChannel, + Message: &protomem.SeenTx{TxKey: txKey[:]}, + } + + memR.peerIDs.Range(func(key, _ interface{}) bool { + peerID := key.(p2p.ID) + peer := memR.Switch.Peers().Get(peerID) + memR.Logger.Debug("Sending SeenTx...", "tx", txKey, "peer", peerID) + + if peerState, ok := peer.Get(types.PeerStateKey).(PeerState); ok { + // make sure peer isn't too far behind. This can happen + // if the peer is block-synching still and catching up + // in which case we just skip sending the transaction + if peerState.GetHeight() < memR.mempool.Height()-peerHeightDiff { + memR.Logger.Debug("peer is too far behind us. Skipping broadcast of seen tx") + return true + } + } + // no need to send a seen tx message to a peer that already + // has that tx. 
+ if memR.seenByPeersSet.Has(txKey, peerID) { + memR.Logger.Debug("Peer has seen the transaction, not sending SeenTx", "tx", txKey, "peer", peerID) + return true + } + + if peer.Send(msg) { + memR.Logger.Debug("SeenTx sent", "tx", txKey, "peer", peerID) + return true + } + memR.Logger.Error("send SeenTx message failed", "txKey", txKey, "peerID", peer.ID()) + return true + }) +} + +// broadcastNewTx broadcast new transaction to all peers unless we are already +// sure they have seen the tx. +func (memR *Reactor) broadcastNewTx(entry *mempool.CListEntry) { + if !memR.Config.Broadcast { + return + } + + // If the node is catching up, don't start this routine immediately. + if memR.WaitSync() { + select { + case <-memR.WaitSyncChan(): + // EnableInOutTxs() has set WaitSync() to false. + case <-memR.Quit(): + return + } + } + memR.Logger.Info("Start tx propagation") + + tx := entry.Tx() + txKey := tx.Key() + memR.Logger.Debug("Broadcasting new transaction...", "tx", txKey.String()) + + msg := p2p.Envelope{ + ChannelID: mempool.MempoolChannel, + Message: &protomem.Txs{Txs: [][]byte{tx}}, + } + + memR.peerIDs.Range(func(key, _ interface{}) bool { + peerID := key.(p2p.ID) + peer := memR.Switch.Peers().Get(peerID) + memR.Logger.Debug("Sending new transaction to...", "tx", txKey, "peer", peerID) + + if peerState, ok := peer.Get(types.PeerStateKey).(PeerState); ok { + // make sure peer isn't too far behind. This can happen + // if the peer is blocksyncing still and catching up + // in which case we just skip sending the transaction + if peerState.GetHeight() < entry.Height()-peerHeightDiff { + memR.Logger.Debug("Peer is too far behind us, don't send new tx") + return true + } + } + + if memR.seenByPeersSet.Has(txKey, peerID) { + memR.Logger.Debug("Peer has seen the transaction, not sending it", "tx", txKey, "peer", peerID) + return true + } + + if peer.Send(msg) { + memR.Logger.Debug("New transaction sent", "tx", txKey, "peer", peerID) + memR.markPeerHasTx(peerID, txKey) + return true + } + + memR.Logger.Error("send Txs message with new transactions failed", "txKey", txKey, "peerID", peerID) + return true + }) +} + +// requestTx requests a transaction from a peer and tracks it, +// requesting it from another peer if the first peer does not respond. +func (memR *Reactor) requestTx(txKey types.TxKey, peerID p2p.ID) { + if !memR.Config.Broadcast { + return + } + + if !memR.Switch.Peers().Has(peerID) { + // we have disconnected from the peer + return + } + + memR.Logger.Debug("requesting tx", "txKey", txKey, "peerID", peerID) + peer := memR.Switch.Peers().Get(peerID) + msg := p2p.Envelope{ + ChannelID: MempoolStateChannel, + Message: &protomem.WantTx{TxKey: txKey[:]}, + } + if peer.Send(msg) { + memR.mempool.Metrics.RequestedTxs.Add(1) + added := memR.requests.Add(txKey, peerID, memR.findNewPeerToRequestTx) + if !added { + memR.Logger.Error("have already marked a tx as requested", "txKey", txKey, "peerID", peerID) + } + } else { + memR.Logger.Error("send WantTx message failed", "txKey", txKey, "peerID", peerID) + } +} + +// findNewPeerToSendTx finds a new peer that has already seen the transaction to +// request a transaction from. 
+func (memR *Reactor) findNewPeerToRequestTx(txKey types.TxKey) { + // ensure that we are connected to peers + if memR.Switch.Peers().Size() == 0 { + return + } + + // pop the next peer in the list of remaining peers that have seen the tx + // and does not already have an outbound request for that tx + seenMap := memR.seenByPeersSet.Get(txKey) + var peerID *p2p.ID + for possiblePeer := range seenMap { + possiblePeer := possiblePeer + if !memR.requests.Has(possiblePeer, txKey) { + peerID = &possiblePeer + break + } + } + + if peerID == nil { + // No other free peer has the transaction we are looking for. + // We give up 🤷‍♂️ and hope either a peer responds late or the tx + // is gossiped again + memR.mempool.Metrics.NoPeerForTx.Add(1) + memR.Logger.Info("no other peer has the tx we are looking for", "txKey", txKey) + return + } + + if !memR.Switch.Peers().Has(*peerID) { + // we disconnected from that peer, retry again until we exhaust the list + memR.findNewPeerToRequestTx(txKey) + } else { + memR.mempool.Metrics.RerequestedTxs.Add(1) + memR.requestTx(txKey, *peerID) + } +} diff --git a/mempool/cat/requests.go b/mempool/cat/requests.go new file mode 100644 index 00000000000..6f9f4377c34 --- /dev/null +++ b/mempool/cat/requests.go @@ -0,0 +1,153 @@ +package cat + +import ( + "sync" + "time" + + "github.com/cometbft/cometbft/p2p" + "github.com/cometbft/cometbft/types" +) + +const defaultGlobalRequestTimeout = 1 * time.Hour + +// requestScheduler tracks the lifecycle of outbound transaction requests. +type requestScheduler struct { + mtx sync.Mutex + + // responseTime is the time the scheduler + // waits for a response from a peer before + // invoking the callback + responseTime time.Duration + + // globalTimeout represents the longest duration + // to wait for any late response (after the responseTime). + // After this period the request is garbage collected. + globalTimeout time.Duration + + // requestsByPeer is a lookup table of requests by peer. + // Multiple transactions can be requested by a single peer at one + requestsByPeer map[p2p.ID]requestSet + + // requestsByTx is a lookup table for requested txs. + // There can only be one request per tx. + requestsByTx map[types.TxKey]p2p.ID +} + +type requestSet map[types.TxKey]*time.Timer + +func newRequestScheduler(responseTime, globalTimeout time.Duration) *requestScheduler { + return &requestScheduler{ + responseTime: responseTime, + globalTimeout: globalTimeout, + requestsByPeer: make(map[p2p.ID]requestSet), + requestsByTx: make(map[types.TxKey]p2p.ID), + } +} + +// Return true iff the pair (txKey, peerID) was successfully added to the scheduler. +func (r *requestScheduler) Add(txKey types.TxKey, peerID p2p.ID, onTimeout func(key types.TxKey)) bool { + r.mtx.Lock() + defer r.mtx.Unlock() + + // not allowed to have more than one outgoing transaction at once + if _, ok := r.requestsByTx[txKey]; ok { + return false + } + + timer := time.AfterFunc(r.responseTime, func() { + r.mtx.Lock() + delete(r.requestsByTx, txKey) + r.mtx.Unlock() + + // trigger callback. Callback can `Add` the tx back to the scheduler + if onTimeout != nil { + onTimeout(txKey) + } + + // We set another timeout because the peer could still send + // a late response after the first timeout and it's important + // to recognize that it is a transaction in response to a + // request and not a new transaction being broadcasted to the entire + // network. This timer cannot be stopped and is used to ensure + // garbage collection. 
+ time.AfterFunc(r.globalTimeout, func() { + r.mtx.Lock() + defer r.mtx.Unlock() + delete(r.requestsByPeer[peerID], txKey) + }) + }) + if _, ok := r.requestsByPeer[peerID]; !ok { + r.requestsByPeer[peerID] = requestSet{txKey: timer} + } else { + r.requestsByPeer[peerID][txKey] = timer + } + r.requestsByTx[txKey] = peerID + return true +} + +func (r *requestScheduler) ForTx(key types.TxKey) (p2p.ID, bool) { + r.mtx.Lock() + defer r.mtx.Unlock() + + v, ok := r.requestsByTx[key] + return v, ok +} + +func (r *requestScheduler) Has(peer p2p.ID, key types.TxKey) bool { + r.mtx.Lock() + defer r.mtx.Unlock() + + requestSet, ok := r.requestsByPeer[peer] + if !ok { + return false + } + _, ok = requestSet[key] + return ok +} + +func (r *requestScheduler) ClearAllRequestsFrom(peer p2p.ID) requestSet { + r.mtx.Lock() + defer r.mtx.Unlock() + + requests, ok := r.requestsByPeer[peer] + if !ok { + return requestSet{} + } + for _, timer := range requests { + timer.Stop() + } + delete(r.requestsByPeer, peer) + return requests +} + +func (r *requestScheduler) MarkReceived(peer p2p.ID, key types.TxKey) bool { + r.mtx.Lock() + defer r.mtx.Unlock() + + if _, ok := r.requestsByPeer[peer]; !ok { + return false + } + + if timer, ok := r.requestsByPeer[peer][key]; ok { + timer.Stop() + } else { + return false + } + + delete(r.requestsByPeer[peer], key) + delete(r.requestsByTx, key) + return true +} + +// Close stops all timers and clears all requests. +// Add should never be called after `Close`. +func (r *requestScheduler) Close() { + r.mtx.Lock() + defer r.mtx.Unlock() + + for _, requestSet := range r.requestsByPeer { + for _, timer := range requestSet { + timer.Stop() + } + } +} diff --git a/mempool/cat/requests_test.go b/mempool/cat/requests_test.go new file mode 100644 index 00000000000..c0999ba6cfa --- /dev/null +++ b/mempool/cat/requests_test.go @@ -0,0 +1,147 @@ +package cat + +import ( + "fmt" + "strconv" + "sync" + "testing" + "time" + + "github.com/fortytw2/leaktest" + "github.com/stretchr/testify/require" + + "github.com/cometbft/cometbft/p2p" + "github.com/cometbft/cometbft/types" +) + +func TestRequestSchedulerRerequest(t *testing.T) { + var ( + requests = newRequestScheduler(10*time.Millisecond, 1*time.Minute) + tx = types.Tx("tx") + key = tx.Key() + peerA = p2p.ID("1") // should be non-zero + peerB = p2p.ID("2") + ) + t.Cleanup(requests.Close) + + // check zero state + _, exists := requests.ForTx(key) + require.False(t, exists) + require.False(t, requests.Has(peerA, key)) + // marking a tx that was never requested should return false + require.False(t, requests.MarkReceived(peerA, key)) + + // create a request + closeCh := make(chan struct{}) + require.True(t, requests.Add(key, peerA, func(key types.TxKey) { + require.Equal(t, key, key) + // the first peer times out to respond so we ask the second peer + require.True(t, requests.Add(key, peerB, func(key types.TxKey) { + t.Fatal("did not expect to timeout") + })) + close(closeCh) + })) + + // check that the request was added + peer, exists := requests.ForTx(key) + require.True(t, exists) + require.Equal(t, peerA, peer) + require.True(t, requests.Has(peerA, key)) + + // should not be able to add the same request again + require.False(t, requests.Add(key, peerA, nil)) + + // wait for the scheduler to invoke the timeout + <-closeCh + + // check that the request still exists + require.True(t, requests.Has(peerA, key)) + // check that peerB was requested + require.True(t, requests.Has(peerB, key)) + + // There should still be a request for the Tx + peer, 
exists = requests.ForTx(key) + require.True(t, exists) + require.Equal(t, peerB, peer) + + // record a response from peerB + require.True(t, requests.MarkReceived(peerB, key)) + + // peerA comes in later with a response but it's still + // considered a response from an earlier request + require.True(t, requests.MarkReceived(peerA, key)) +} + +func TestRequestSchedulerNonResponsivePeer(t *testing.T) { + var ( + requests = newRequestScheduler(10*time.Millisecond, time.Millisecond) + tx = types.Tx("tx") + key = tx.Key() + peerA = p2p.ID("1") // should be non-zero + ) + + require.True(t, requests.Add(key, peerA, nil)) + require.Eventually(t, func() bool { + _, exists := requests.ForTx(key) + return !exists + }, 100*time.Millisecond, 5*time.Millisecond) +} + +func TestRequestSchedulerConcurrencyAddsAndReads(t *testing.T) { + leaktest.CheckTimeout(t, time.Second)() + requests := newRequestScheduler(10*time.Millisecond, time.Millisecond) + defer requests.Close() + + N := 5 + keys := make([]types.TxKey, N) + for i := 0; i < N; i++ { + tx := types.Tx(fmt.Sprintf("tx%d", i)) + keys[i] = tx.Key() + } + + addWg := sync.WaitGroup{} + receiveWg := sync.WaitGroup{} + doneCh := make(chan struct{}) + for i := 1; i < N*N; i++ { + addWg.Add(1) + go func(i int) { + defer addWg.Done() + peerID := p2p.ID(strconv.Itoa(i)) + requests.Add(keys[i%N], peerID, nil) + }(i) + } + for i := 1; i < N*N; i++ { + receiveWg.Add(1) + go func(peer p2p.ID) { + defer receiveWg.Done() + markReceived := func() { + for _, key := range keys { + if requests.Has(peer, key) { + requests.MarkReceived(peer, key) + } + } + } + for { + select { + case <-doneCh: + // need to ensure this is run + // at least once after all adds + // are done + markReceived() + return + default: + markReceived() + } + } + }(p2p.ID(strconv.Itoa(i))) + } + addWg.Wait() + close(doneCh) + + receiveWg.Wait() + + for _, key := range keys { + _, exists := requests.ForTx(key) + require.False(t, exists) + } +} diff --git a/mempool/cat/seen_txs.go b/mempool/cat/seen_txs.go new file mode 100644 index 00000000000..759da07b4fa --- /dev/null +++ b/mempool/cat/seen_txs.go @@ -0,0 +1,127 @@ +package cat + +import ( + "time" + + cmtsync "github.com/cometbft/cometbft/internal/sync" + "github.com/cometbft/cometbft/p2p" + "github.com/cometbft/cometbft/types" +) + +// SeenTxSet records transactions that have been +// seen by other peers but not yet by us. +type SeenTxSet struct { + mtx cmtsync.Mutex + set map[types.TxKey]timestampedPeerSet +} + +type timestampedPeerSet struct { + peers map[p2p.ID]struct{} + time time.Time // time at which the set was created +} + +func NewSeenTxSet() *SeenTxSet { + return &SeenTxSet{ + set: make(map[types.TxKey]timestampedPeerSet), + } +} + +func (s *SeenTxSet) Add(txKey types.TxKey, peer p2p.ID) { + // if peer == 0 { + // return + // } + s.mtx.Lock() + defer s.mtx.Unlock() + seenSet, exists := s.set[txKey] + if !exists { + s.set[txKey] = timestampedPeerSet{ + peers: map[p2p.ID]struct{}{peer: {}}, + time: time.Now().UTC(), + } + } else { + seenSet.peers[peer] = struct{}{} + } +} + +// only used in tests. 
+func (s *SeenTxSet) Pop(txKey types.TxKey) *p2p.ID { + s.mtx.Lock() + defer s.mtx.Unlock() + seenSet, exists := s.set[txKey] + if exists { + for peer := range seenSet.peers { + delete(seenSet.peers, peer) + return &peer + } + } + return nil +} + +func (s *SeenTxSet) RemoveKey(txKey types.TxKey) { + s.mtx.Lock() + defer s.mtx.Unlock() + delete(s.set, txKey) +} + +func (s *SeenTxSet) Remove(txKey types.TxKey, peer p2p.ID) { + s.mtx.Lock() + defer s.mtx.Unlock() + set, exists := s.set[txKey] + if exists { + if len(set.peers) == 1 { + delete(s.set, txKey) + } else { + delete(set.peers, peer) + } + } +} + +// not used +// func (s *SeenTxSet) Prune(limit time.Time) { +// s.mtx.Lock() +// defer s.mtx.Unlock() +// for key, seenSet := range s.set { +// if seenSet.time.Before(limit) { +// delete(s.set, key) +// } +// } +// } + +func (s *SeenTxSet) Has(txKey types.TxKey, peer p2p.ID) bool { + s.mtx.Lock() + defer s.mtx.Unlock() + seenSet, exists := s.set[txKey] + if !exists { + return false + } + _, has := seenSet.peers[peer] + return has +} + +func (s *SeenTxSet) Get(txKey types.TxKey) map[p2p.ID]struct{} { + s.mtx.Lock() + defer s.mtx.Unlock() + seenSet, exists := s.set[txKey] + if !exists { + return nil + } + // make a copy of the struct to avoid concurrency issues + peers := make(map[p2p.ID]struct{}, len(seenSet.peers)) + for peer := range seenSet.peers { + peers[peer] = struct{}{} + } + return peers +} + +// Len returns the amount of cached items. Mostly used for testing. +func (s *SeenTxSet) Len() int { + s.mtx.Lock() + defer s.mtx.Unlock() + return len(s.set) +} + +func (s *SeenTxSet) Reset() { + s.mtx.Lock() + defer s.mtx.Unlock() + s.set = make(map[types.TxKey]timestampedPeerSet) +} diff --git a/mempool/cat/seen_txs_test.go b/mempool/cat/seen_txs_test.go new file mode 100644 index 00000000000..015cea6d53d --- /dev/null +++ b/mempool/cat/seen_txs_test.go @@ -0,0 +1,86 @@ +package cat + +import ( + "fmt" + "strconv" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/cometbft/cometbft/p2p" + "github.com/cometbft/cometbft/types" +) + +func TestSeenTxSet(t *testing.T) { + var ( + tx1Key = types.Tx("tx1").Key() + tx2Key = types.Tx("tx2").Key() + tx3Key = types.Tx("tx3").Key() + peer1 = p2p.ID("1") + peer2 = p2p.ID("2") + ) + + seenSet := NewSeenTxSet() + require.Zero(t, seenSet.Pop(tx1Key)) + + seenSet.Add(tx1Key, peer1) + seenSet.Add(tx1Key, peer1) + require.Equal(t, 1, seenSet.Len()) + seenSet.Add(tx1Key, peer2) + peers := seenSet.Get(tx1Key) + require.NotNil(t, peers) + require.Equal(t, map[p2p.ID]struct{}{peer1: {}, peer2: {}}, peers) + seenSet.Add(tx2Key, peer1) + seenSet.Add(tx3Key, peer1) + require.Equal(t, 3, seenSet.Len()) + seenSet.RemoveKey(tx2Key) + require.Equal(t, 2, seenSet.Len()) + require.Zero(t, seenSet.Pop(tx2Key)) + require.Equal(t, peer1, *seenSet.Pop(tx3Key)) +} + +func TestSeenTxSetConcurrency(_ *testing.T) { + seenSet := NewSeenTxSet() + + const ( + concurrency = 10 + numTx = 100 + ) + + wg := sync.WaitGroup{} + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func(i uint16) { + defer wg.Done() + for i := 0; i < numTx; i++ { + tx := types.Tx([]byte(fmt.Sprintf("tx%d", i))) + seenSet.Add(tx.Key(), p2p.ID(strconv.Itoa(i))) + } + }(uint16(i % 2)) + } + time.Sleep(time.Millisecond) + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func(peer uint16) { + defer wg.Done() + for i := 0; i < numTx; i++ { + tx := types.Tx([]byte(fmt.Sprintf("tx%d", i))) + seenSet.Has(tx.Key(), p2p.ID(strconv.Itoa(i))) + } + }(uint16(i % 2)) + } + 
time.Sleep(time.Millisecond) + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func(peer uint16) { + defer wg.Done() + for i := numTx - 1; i >= 0; i-- { + tx := types.Tx([]byte(fmt.Sprintf("tx%d", i))) + seenSet.RemoveKey(tx.Key()) + } + }(uint16(i % 2)) + } + wg.Wait() +} diff --git a/mempool/clist_mempool.go b/mempool/clist_mempool.go index d499b3ed90a..253539e67c4 100644 --- a/mempool/clist_mempool.go +++ b/mempool/clist_mempool.go @@ -18,6 +18,50 @@ import ( "github.com/cometbft/cometbft/types" ) +// mempoolTx is an entry in the clist. +type mempoolTx struct { + tx types.Tx // validated by the application + height int64 // height that this tx had been validated in + gasWanted int64 // amount of gas this tx states it will require +} + +func (memTx *mempoolTx) Height() int64 { + return atomic.LoadInt64(&memTx.height) +} + +func (memTx *mempoolTx) Tx() types.Tx { + return memTx.tx +} + +// CElement wrapper. +type CListEntry struct { + elem *clist.CElement +} + +func (e *CListEntry) IsEmpty() bool { + return e == nil || e.elem == nil +} + +func (e *CListEntry) Tx() types.Tx { + return e.elem.Value.(*mempoolTx).tx +} + +func (e *CListEntry) Height() int64 { + return e.elem.Value.(*mempoolTx).Height() +} + +func (e *CListEntry) GasWanted() int64 { + return e.elem.Value.(*mempoolTx).gasWanted +} + +func (e *CListEntry) NextWaitChan() <-chan struct{} { + return e.elem.NextWaitChan() +} + +func (e *CListEntry) Next() *CListEntry { + return &CListEntry{e.elem.Next()} +} + // CListMempool is an ordered in-memory pool for transactions before they are // proposed in a consensus round. Transaction validity is checked using the // CheckTx abci message before the transaction is added to the pool. The @@ -36,6 +80,10 @@ type CListMempool struct { // from the mempool. removeTxOnReactorCb func(txKey types.TxKey) + // Function set by the reactor to be called when a new transaction is added + // to the mempool. Used only by the CAT mempool reactor. + newTxReceivedCb func(txKey types.TxKey) + config *config.MempoolConfig // Exclusive mutex for Update method to prevent concurrent execution of @@ -60,10 +108,13 @@ type CListMempool struct { // Keep a cache of already-seen txs. // This reduces the pressure on the proxyApp. - cache TxCache + cache TxCache[types.TxKey] + + // Keep a cache of invalid transactions. 
+ rejectedTxsCache TxCache[types.TxKey] logger log.Logger - metrics *Metrics + Metrics *Metrics } var _ Mempool = &CListMempool{} @@ -80,20 +131,21 @@ func NewCListMempool( options ...CListMempoolOption, ) *CListMempool { mp := &CListMempool{ - config: cfg, - proxyAppConn: proxyAppConn, - txs: clist.New(), - height: height, - recheckCursor: nil, - recheckEnd: nil, - logger: log.NewNopLogger(), - metrics: NopMetrics(), + config: cfg, + proxyAppConn: proxyAppConn, + txs: clist.New(), + height: height, + recheckCursor: nil, + recheckEnd: nil, + rejectedTxsCache: NewLRUTxCache[types.TxKey](cfg.CacheSize), // TODO: check size + logger: log.NewNopLogger(), + Metrics: NopMetrics(), } if cfg.CacheSize > 0 { - mp.cache = NewLRUTxCache(cfg.CacheSize) + mp.cache = NewLRUTxCache[types.TxKey](cfg.CacheSize) } else { - mp.cache = NopTxCache{} + mp.cache = NopTxCache[types.TxKey]{} } proxyAppConn.SetResponseCallback(mp.globalCb) @@ -105,32 +157,36 @@ func NewCListMempool( return mp } -func (mem *CListMempool) getCElement(txKey types.TxKey) (*clist.CElement, bool) { - if e, ok := mem.txsMap.Load(txKey); ok { - return e.(*clist.CElement), true +func (mem *CListMempool) GetEntry(txKey types.TxKey) *CListEntry { + if elem, ok := mem.txsMap.Load(txKey); ok { + return &CListEntry{elem.(*clist.CElement)} } - return nil, false + return nil } func (mem *CListMempool) InMempool(txKey types.TxKey) bool { - _, ok := mem.getCElement(txKey) - return ok + elem := mem.GetEntry(txKey) + return elem != nil } -func (mem *CListMempool) addToCache(tx types.Tx) bool { - return mem.cache.Push(tx) +func (mem *CListMempool) addToCache(txKey types.TxKey) bool { + return mem.cache.Push(txKey) } -func (mem *CListMempool) forceRemoveFromCache(tx types.Tx) { - mem.cache.Remove(tx) +func (mem *CListMempool) InCache(txKey types.TxKey) bool { + return mem.cache.Has(txKey) +} + +func (mem *CListMempool) forceRemoveFromCache(txKey types.TxKey) { + mem.cache.Remove(txKey) } // tryRemoveFromCache removes a transaction from the cache in case it can be // added to the mempool at a later stage (probably when the transaction becomes // valid). -func (mem *CListMempool) tryRemoveFromCache(tx types.Tx) { +func (mem *CListMempool) tryRemoveFromCache(txKey types.TxKey) { if !mem.config.KeepInvalidTxsInCache { - mem.forceRemoveFromCache(tx) + mem.forceRemoveFromCache(txKey) } } @@ -147,6 +203,10 @@ func (mem *CListMempool) removeAllTxs() { }) } +func (mem *CListMempool) WasRejected(txKey types.TxKey) bool { + return mem.rejectedTxsCache.Has(txKey) +} + // NOTE: not thread safe - should only be called once, on startup. func (mem *CListMempool) EnableTxsAvailable() { mem.txsAvailable = make(chan struct{}, 1) @@ -185,7 +245,7 @@ func WithPostCheck(f PostCheckFunc) CListMempoolOption { // WithMetrics sets the metrics. func WithMetrics(metrics *Metrics) CListMempoolOption { - return func(mem *CListMempool) { mem.metrics = metrics } + return func(mem *CListMempool) { mem.Metrics = metrics } } // Safe for concurrent use by multiple goroutines. @@ -229,13 +289,20 @@ func (mem *CListMempool) Flush() { mem.removeAllTxs() } +// Height returns the latest height that the mempool is at. +func (mem *CListMempool) Height() int64 { + mem.updateMtx.Lock() + defer mem.updateMtx.Unlock() + return mem.height +} + // TxsFront returns the first transaction in the ordered list for peer // goroutines to call .NextWait() on. // FIXME: leaking implementation details! // // Safe for concurrent use by multiple goroutines. 
-func (mem *CListMempool) TxsFront() *clist.CElement { - return mem.txs.Front() +func (mem *CListMempool) TxsFront() *CListEntry { + return &CListEntry{mem.txs.Front()} } // TxsWaitChan returns a channel to wait on transactions. It will be closed @@ -280,9 +347,9 @@ func (mem *CListMempool) CheckTx(tx types.Tx) (*abcicli.ReqRes, error) { return nil, ErrAppConnMempool{Err: err} } - if added := mem.addToCache(tx); !added { + if added := mem.addToCache(tx.Key()); !added { mem.logger.Debug("Not cached", "tx", tx) - mem.metrics.AlreadyReceivedTxs.Add(1) + mem.Metrics.AlreadyReceivedTxs.Add(1) // TODO: consider punishing peer for dups, // its non-trivial since invalid txs can become valid, // but they can spam the same tx with little cost to them atm. @@ -302,6 +369,22 @@ func (mem *CListMempool) CheckTx(tx types.Tx) (*abcicli.ReqRes, error) { return reqRes, nil } +func (mem *CListMempool) SetNewTxReceivedCallback(cb func(txKey types.TxKey)) { + mem.newTxReceivedCb = cb +} + +func (mem *CListMempool) InvokeNewTxReceivedOnReactor(txKey types.TxKey) { + if mem.newTxReceivedCb != nil { + mem.newTxReceivedCb(txKey) + } +} + +func (mem *CListMempool) CheckNewTx(tx types.Tx) (*abcicli.ReqRes, error) { + mem.logger.Debug("Tx received from RPC endpoint", "tx", tx.Key().String()) + reqRes, err := mem.CheckTx(tx) + return reqRes, err +} + // Global callback that will be called after every ABCI response. func (mem *CListMempool) globalCb(req *abci.Request, res *abci.Response) { switch res.Value.(type) { @@ -319,7 +402,7 @@ func (mem *CListMempool) globalCb(req *abci.Request, res *abci.Response) { if mem.recheckCursor == nil { return } - mem.metrics.RecheckTimes.Add(1) + mem.Metrics.RecheckTimes.Add(1) mem.resCbRecheck(req, res) default: @@ -327,8 +410,8 @@ func (mem *CListMempool) globalCb(req *abci.Request, res *abci.Response) { } // update metrics - mem.metrics.Size.Set(float64(mem.Size())) - mem.metrics.SizeBytes.Set(float64(mem.SizeBytes())) + mem.Metrics.Size.Set(float64(mem.Size())) + mem.Metrics.SizeBytes.Set(float64(mem.SizeBytes())) default: // ignore other messages @@ -341,7 +424,7 @@ func (mem *CListMempool) addTx(memTx *mempoolTx) { e := mem.txs.PushBack(memTx) mem.txsMap.Store(memTx.tx.Key(), e) atomic.AddInt64(&mem.txsBytes, int64(len(memTx.tx))) - mem.metrics.TxSizeBytes.Observe(float64(len(memTx.tx))) + mem.Metrics.TxSizeBytes.Observe(float64(len(memTx.tx))) mem.logger.Debug("Clisted", "tx", memTx.tx) } @@ -353,12 +436,11 @@ func (mem *CListMempool) RemoveTxByKey(txKey types.TxKey) error { // The transaction should be removed from the reactor, even if it cannot be // found in the mempool. mem.invokeRemoveTxOnReactor(txKey) - if elem, ok := mem.getCElement(txKey); ok { - mem.txs.Remove(elem) - elem.DetachPrev() + if entry := mem.GetEntry(txKey); entry != nil { + mem.txs.Remove(entry.elem) + entry.elem.DetachPrev() mem.txsMap.Delete(txKey) - tx := elem.Value.(*mempoolTx).tx - atomic.AddInt64(&mem.txsBytes, int64(-len(tx))) + atomic.AddInt64(&mem.txsBytes, int64(-len(entry.Tx()))) return nil } return ErrTxNotFound @@ -401,7 +483,7 @@ func (mem *CListMempool) resCbFirstTime( // Check mempool isn't full again to reduce the chance of exceeding the // limits. 
if err := mem.isFull(len(tx)); err != nil { - mem.forceRemoveFromCache(tx) // mempool might have space later + mem.forceRemoveFromCache(txKey) // mempool might have space later mem.logger.Error(err.Error()) return } @@ -432,14 +514,15 @@ func (mem *CListMempool) resCbFirstTime( ) mem.notifyTxsAvailable() } else { - mem.tryRemoveFromCache(tx) + mem.tryRemoveFromCache(txKey) + mem.rejectedTxsCache.Push(txKey) mem.logger.Debug( "rejected invalid transaction", "tx", types.Tx(tx).Hash(), "res", r, "err", postCheckErr, ) - mem.metrics.FailedTxs.Add(1) + mem.Metrics.FailedTxs.Add(1) } default: @@ -454,7 +537,7 @@ func (mem *CListMempool) resCbFirstTime( func (mem *CListMempool) resCbRecheck(req *abci.Request, res *abci.Response) { switch r := res.Value.(type) { case *abci.Response_CheckTx: - tx := req.GetCheckTx().Tx + tx := types.Tx(req.GetCheckTx().Tx) memTx := mem.recheckCursor.Value.(*mempoolTx) // Search through the remaining list of tx to recheck for a transaction that matches @@ -469,7 +552,7 @@ func (mem *CListMempool) resCbRecheck(req *abci.Request, res *abci.Response) { mem.logger.Error( "re-CheckTx transaction mismatch", - "got", types.Tx(tx), + "got", tx, "expected", memTx.tx, ) @@ -492,11 +575,12 @@ func (mem *CListMempool) resCbRecheck(req *abci.Request, res *abci.Response) { if (r.CheckTx.Code != abci.CodeTypeOK) || postCheckErr != nil { // Tx became invalidated due to newly committed block. - mem.logger.Debug("tx is no longer valid", "tx", types.Tx(tx).Hash(), "res", r, "err", postCheckErr) - if err := mem.RemoveTxByKey(memTx.tx.Key()); err != nil { + mem.logger.Debug("tx is no longer valid", "tx", tx.Hash(), "res", r, "err", postCheckErr) + if err := mem.RemoveTxByKey(tx.Key()); err != nil { mem.logger.Debug("Transaction could not be removed from mempool", "err", err) } - mem.tryRemoveFromCache(tx) + mem.tryRemoveFromCache(tx.Key()) + mem.rejectedTxsCache.Push(tx.Key()) } if mem.recheckCursor == mem.recheckEnd { mem.recheckCursor = nil @@ -615,11 +699,13 @@ func (mem *CListMempool) Update( } for i, tx := range txs { + txKey := tx.Key() if txResults[i].Code == abci.CodeTypeOK { // Add valid committed tx to the cache (if missing). - _ = mem.addToCache(tx) + _ = mem.addToCache(txKey) } else { - mem.tryRemoveFromCache(tx) + mem.tryRemoveFromCache(txKey) + mem.rejectedTxsCache.Push(tx.Key()) } // Remove committed tx from the mempool. @@ -654,8 +740,8 @@ func (mem *CListMempool) Update( } // Update metrics - mem.metrics.Size.Set(float64(mem.Size())) - mem.metrics.SizeBytes.Set(float64(mem.SizeBytes())) + mem.Metrics.Size.Set(float64(mem.Size())) + mem.Metrics.SizeBytes.Set(float64(mem.SizeBytes())) return nil } diff --git a/mempool/clist_mempool_test.go b/mempool/clist_mempool_test.go index bfb4dad6069..b2831db3dd0 100644 --- a/mempool/clist_mempool_test.go +++ b/mempool/clist_mempool_test.go @@ -142,10 +142,10 @@ func TestReapMaxBytesMaxGas(t *testing.T) { // Ensure gas calculation behaves as expected checkTxs(t, mp, 1) - tx0 := mp.TxsFront().Value.(*mempoolTx) - require.Equal(t, tx0.gasWanted, int64(1), "transactions gas was set incorrectly") + tx0 := mp.TxsFront() + require.Equal(t, tx0.GasWanted(), int64(1), "transactions gas was set incorrectly") // ensure each tx is 20 bytes long - require.Len(t, tx0.tx, 20, "Tx is longer than 20 bytes") + require.Len(t, tx0.Tx(), 20, "Tx is longer than 20 bytes") mp.Flush() // each table driven test creates numTxsToCreate txs with checkTx, and at the end clears all remaining txs. 
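Editorial note on the new-transaction hook introduced in clist_mempool.go above: SetNewTxReceivedCallback stores a callback on the CListMempool and InvokeNewTxReceivedOnReactor fires it; the RPC handlers invoke the latter once a transaction submitted through CheckNewTx passes CheckTx. The following sketch is not part of the diff; it only illustrates how a push-pull style reactor might register the callback. The package name and the seenTxBroadcaster interface are hypothetical stand-ins for whatever component announces new transactions (for example, by sending SeenTx messages).

package catwiring

import (
	"github.com/cometbft/cometbft/mempool"
	"github.com/cometbft/cometbft/types"
)

// seenTxBroadcaster is a hypothetical stand-in for a reactor that announces
// newly received transactions to its peers (e.g. via SeenTx messages).
type seenTxBroadcaster interface {
	BroadcastSeenTx(txKey types.TxKey)
}

// wireNewTxCallback registers the callback that InvokeNewTxReceivedOnReactor
// fires after an RPC-submitted transaction is accepted into the mempool.
func wireNewTxCallback(mp *mempool.CListMempool, b seenTxBroadcaster) {
	mp.SetNewTxReceivedCallback(func(txKey types.TxKey) {
		b.BroadcastSeenTx(txKey)
	})
}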
@@ -354,7 +354,7 @@ func TestMempool_KeepInvalidTxsInCache(t *testing.T) { binary.BigEndian.PutUint64(a, 0) // remove a from the cache to test (2) - mp.cache.Remove(a) + mp.cache.Remove(types.Tx(a).Key()) _, err := mp.CheckTx(a) require.NoError(t, err) @@ -664,7 +664,7 @@ func TestMempoolNoCacheOverflow(t *testing.T) { defer cleanup() // add tx0 - tx0 := kvstore.NewTxFromID(0) + tx0 := types.Tx(kvstore.NewTxFromID(0)) _, err := mp.CheckTx(tx0) require.NoError(t, err) err = mp.FlushAppConn() @@ -677,7 +677,7 @@ func TestMempoolNoCacheOverflow(t *testing.T) { } err = mp.FlushAppConn() require.NoError(t, err) - assert.False(t, mp.cache.Has(kvstore.NewTxFromID(0))) + assert.False(t, mp.cache.Has(tx0.Key())) // add again tx0 _, err = mp.CheckTx(tx0) diff --git a/mempool/mempool.go b/mempool/mempool.go index c6a89697ddd..dd03f88b6c9 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -27,6 +27,11 @@ type Mempool interface { // its validity and whether it should be added to the mempool. CheckTx(tx types.Tx) (*abcicli.ReqRes, error) + // From RPC endpoint + CheckNewTx(tx types.Tx) (*abcicli.ReqRes, error) + + InvokeNewTxReceivedOnReactor(txKey types.TxKey) + // RemoveTxByKey removes a transaction, identified by its key, // from the mempool. RemoveTxByKey(txKey types.TxKey) error diff --git a/mempool/mempoolTx.go b/mempool/mempoolTx.go deleted file mode 100644 index 48313461a96..00000000000 --- a/mempool/mempoolTx.go +++ /dev/null @@ -1,19 +0,0 @@ -package mempool - -import ( - "sync/atomic" - - "github.com/cometbft/cometbft/types" -) - -// mempoolTx is an entry in the mempool. -type mempoolTx struct { - height int64 // height that this tx had been validated in - gasWanted int64 // amount of gas this tx states it will require - tx types.Tx // validated by the application -} - -// Height returns the height for this transaction. 
-func (memTx *mempoolTx) Height() int64 { - return atomic.LoadInt64(&memTx.height) -} diff --git a/mempool/metrics.gen.go b/mempool/metrics.gen.go index 6714c739711..6676b8ade6e 100644 --- a/mempool/metrics.gen.go +++ b/mempool/metrics.gen.go @@ -64,6 +64,24 @@ func PrometheusMetrics(namespace string, labelsAndValues ...string) *Metrics { Name: "active_outbound_connections", Help: "Number of connections being actively used for gossiping transactions (experimental feature).", }, labels).With(labelsAndValues...), + RequestedTxs: prometheus.NewCounterFrom(stdprometheus.CounterOpts{ + Namespace: namespace, + Subsystem: MetricsSubsystem, + Name: "requested_txs", + Help: "Number of requested transactions (WantTx messages).", + }, labels).With(labelsAndValues...), + RerequestedTxs: prometheus.NewCounterFrom(stdprometheus.CounterOpts{ + Namespace: namespace, + Subsystem: MetricsSubsystem, + Name: "rerequested_txs", + Help: "Number of re-requested transactions.", + }, labels).With(labelsAndValues...), + NoPeerForTx: prometheus.NewCounterFrom(stdprometheus.CounterOpts{ + Namespace: namespace, + Subsystem: MetricsSubsystem, + Name: "no_peer_for_tx", + Help: "Number of times we cannot find a peer for a tx.", + }, labels).With(labelsAndValues...), } } @@ -77,5 +95,8 @@ func NopMetrics() *Metrics { RecheckTimes: discard.NewCounter(), AlreadyReceivedTxs: discard.NewCounter(), ActiveOutboundConnections: discard.NewGauge(), + RequestedTxs: discard.NewCounter(), + RerequestedTxs: discard.NewCounter(), + NoPeerForTx: discard.NewCounter(), } } diff --git a/mempool/metrics.go b/mempool/metrics.go index e8148c9fb06..1b32015389f 100644 --- a/mempool/metrics.go +++ b/mempool/metrics.go @@ -44,4 +44,20 @@ type Metrics struct { // Number of connections being actively used for gossiping transactions // (experimental feature). ActiveOutboundConnections metrics.Gauge + + // RequestedTxs defines the number of times that the node requested a + // tx from a peer. + // metrics:Number of requested transactions (WantTx messages). + RequestedTxs metrics.Counter + + // RerequestedTxs defines the number of times that a requested tx + // never received a response in time and a new request was made. + // metrics:Number of re-requested transactions. + RerequestedTxs metrics.Counter + + // NoPeerForTx counts the number of times the reactor exhausts the list of + // peers looking for a transaction for which it has received a SeenTx + // message. + // metrics:Number of times we cannot find a peer for a tx.
+ NoPeerForTx metrics.Counter } diff --git a/mempool/mocks/mempool.go b/mempool/mocks/mempool.go index 4a2aba0a218..3c248291afc 100644 --- a/mempool/mocks/mempool.go +++ b/mempool/mocks/mempool.go @@ -18,6 +18,36 @@ type Mempool struct { mock.Mock } +// CheckNewTx provides a mock function with given fields: tx +func (_m *Mempool) CheckNewTx(tx types.Tx) (*abcicli.ReqRes, error) { + ret := _m.Called(tx) + + if len(ret) == 0 { + panic("no return value specified for CheckNewTx") + } + + var r0 *abcicli.ReqRes + var r1 error + if rf, ok := ret.Get(0).(func(types.Tx) (*abcicli.ReqRes, error)); ok { + return rf(tx) + } + if rf, ok := ret.Get(0).(func(types.Tx) *abcicli.ReqRes); ok { + r0 = rf(tx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*abcicli.ReqRes) + } + } + + if rf, ok := ret.Get(1).(func(types.Tx) error); ok { + r1 = rf(tx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // CheckTx provides a mock function with given fields: tx func (_m *Mempool) CheckTx(tx types.Tx) (*abcicli.ReqRes, error) { ret := _m.Called(tx) @@ -76,6 +106,11 @@ func (_m *Mempool) FlushAppConn() error { return r0 } +// InvokeNewTxReceivedOnReactor provides a mock function with given fields: txKey +func (_m *Mempool) InvokeNewTxReceivedOnReactor(txKey types.TxKey) { + _m.Called(txKey) +} + // Lock provides a mock function with given fields: func (_m *Mempool) Lock() { _m.Called() diff --git a/mempool/nop_mempool.go b/mempool/nop_mempool.go index 06c805a8f4f..1644e555ee3 100644 --- a/mempool/nop_mempool.go +++ b/mempool/nop_mempool.go @@ -26,6 +26,12 @@ func (*NopMempool) CheckTx(types.Tx) (*abcicli.ReqRes, error) { return nil, errNotAllowed } +func (*NopMempool) CheckNewTx(types.Tx) (*abcicli.ReqRes, error) { + return nil, errNotAllowed +} + +func (*NopMempool) InvokeNewTxReceivedOnReactor(txKey types.TxKey) {} + // RemoveTxByKey always returns an error. func (*NopMempool) RemoveTxByKey(types.TxKey) error { return errNotAllowed } diff --git a/mempool/reactor.go b/mempool/reactor.go index 302c916fb7f..899afb67153 100644 --- a/mempool/reactor.go +++ b/mempool/reactor.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "sync/atomic" "time" "golang.org/x/sync/semaphore" @@ -12,7 +11,6 @@ import ( abci "github.com/cometbft/cometbft/abci/types" protomem "github.com/cometbft/cometbft/api/cometbft/mempool/v1" cfg "github.com/cometbft/cometbft/config" - "github.com/cometbft/cometbft/internal/clist" cmtsync "github.com/cometbft/cometbft/internal/sync" "github.com/cometbft/cometbft/libs/log" "github.com/cometbft/cometbft/p2p" @@ -22,14 +20,11 @@ import ( // Reactor handles mempool tx broadcasting amongst peers. // It maintains a map from peer ID to counter, to prevent gossiping txs to the // peers you received it from. +// TODO: move file to its own package. type Reactor struct { - p2p.BaseReactor - config *cfg.MempoolConfig + WaitSyncReactor mempool *CListMempool - waitSync atomic.Bool - waitSyncCh chan struct{} // for signaling when to start receiving and sending txs - // `txSenders` maps every received transaction to the set of peer IDs that // have sent the transaction to this node. Sender IDs are used during // transaction propagation to avoid sending a transaction to a peer that @@ -45,21 +40,18 @@ type Reactor struct { } // NewReactor returns a new Reactor with the given config and mempool. 
-func NewReactor(config *cfg.MempoolConfig, mempool *CListMempool, waitSync bool) *Reactor { +func NewReactor(config *cfg.MempoolConfig, mempool *CListMempool, waitSync bool, logger log.Logger) *Reactor { memR := &Reactor{ - config: config, - mempool: mempool, - waitSync: atomic.Bool{}, - txSenders: make(map[types.TxKey]map[p2p.ID]bool), + WaitSyncReactor: *NewWaitSyncReactor(config, waitSync), + mempool: mempool, + txSenders: make(map[types.TxKey]map[p2p.ID]bool), } memR.BaseReactor = *p2p.NewBaseReactor("Mempool", memR) - if waitSync { - memR.waitSync.Store(true) - memR.waitSyncCh = make(chan struct{}) - } + + memR.SetLogger(logger) memR.mempool.SetTxRemovedCallback(func(txKey types.TxKey) { memR.removeSenders(txKey) }) - memR.activePersistentPeersSemaphore = semaphore.NewWeighted(int64(memR.config.ExperimentalMaxGossipConnectionsToPersistentPeers)) - memR.activeNonPersistentPeersSemaphore = semaphore.NewWeighted(int64(memR.config.ExperimentalMaxGossipConnectionsToNonPersistentPeers)) + memR.activePersistentPeersSemaphore = semaphore.NewWeighted(int64(config.ExperimentalMaxGossipConnectionsToPersistentPeers)) + memR.activeNonPersistentPeersSemaphore = semaphore.NewWeighted(int64(config.ExperimentalMaxGossipConnectionsToNonPersistentPeers)) return memR } @@ -75,7 +67,7 @@ func (memR *Reactor) OnStart() error { if memR.WaitSync() { memR.Logger.Info("Starting reactor in sync mode: tx propagation will start once sync completes") } - if !memR.config.Broadcast { + if !memR.Config.Broadcast { memR.Logger.Info("Tx broadcasting is disabled") } return nil @@ -84,7 +76,7 @@ func (memR *Reactor) OnStart() error { // GetChannels implements Reactor by returning the list of channels for this // reactor. func (memR *Reactor) GetChannels() []*p2p.ChannelDescriptor { - largestTx := make([]byte, memR.config.MaxTxBytes) + largestTx := make([]byte, memR.Config.MaxTxBytes) batchMsg := protomem.Message{ Sum: &protomem.Message_Txs{ Txs: &protomem.Txs{Txs: [][]byte{largestTx}}, @@ -104,15 +96,15 @@ func (memR *Reactor) GetChannels() []*p2p.ChannelDescriptor { // AddPeer implements Reactor. // It starts a broadcast routine ensuring all txs are forwarded to the given peer. func (memR *Reactor) AddPeer(peer p2p.Peer) { - if memR.config.Broadcast { + if memR.Config.Broadcast { go func() { // Always forward transactions to unconditional peers. if !memR.Switch.IsPeerUnconditional(peer.ID()) { // Depending on the type of peer, we choose a semaphore to limit the gossiping peers. 
var peerSemaphore *semaphore.Weighted - if peer.IsPersistent() && memR.config.ExperimentalMaxGossipConnectionsToPersistentPeers > 0 { + if peer.IsPersistent() && memR.Config.ExperimentalMaxGossipConnectionsToPersistentPeers > 0 { peerSemaphore = memR.activePersistentPeersSemaphore - } else if !peer.IsPersistent() && memR.config.ExperimentalMaxGossipConnectionsToNonPersistentPeers > 0 { + } else if !peer.IsPersistent() && memR.Config.ExperimentalMaxGossipConnectionsToNonPersistentPeers > 0 { peerSemaphore = memR.activeNonPersistentPeersSemaphore } @@ -137,8 +129,8 @@ func (memR *Reactor) AddPeer(peer p2p.Peer) { } } - memR.mempool.metrics.ActiveOutboundConnections.Add(1) - defer memR.mempool.metrics.ActiveOutboundConnections.Add(-1) + memR.mempool.Metrics.ActiveOutboundConnections.Add(1) + defer memR.mempool.Metrics.ActiveOutboundConnections.Add(-1) memR.broadcastTxRoutine(peer) }() } @@ -192,22 +184,6 @@ func (memR *Reactor) Receive(e p2p.Envelope) { // broadcasting happens from go routines per peer } -func (memR *Reactor) EnableInOutTxs() { - memR.Logger.Info("enabling inbound and outbound transactions") - if !memR.waitSync.CompareAndSwap(true, false) { - return - } - - // Releases all the blocked broadcastTxRoutine instances. - if memR.config.Broadcast { - close(memR.waitSyncCh) - } -} - -func (memR *Reactor) WaitSync() bool { - return memR.waitSync.Load() -} - // PeerState describes the state of a peer. type PeerState interface { GetHeight() int64 @@ -215,17 +191,18 @@ type PeerState interface { // Send new mempool txs to peer. func (memR *Reactor) broadcastTxRoutine(peer p2p.Peer) { - var next *clist.CElement + var next *CListEntry // If the node is catching up, don't start this routine immediately. if memR.WaitSync() { select { - case <-memR.waitSyncCh: + case <-memR.WaitSyncChan(): // EnableInOutTxs() has set WaitSync() to false. case <-memR.Quit(): return } } + memR.Logger.Info("Start tx propagation", "peer", peer.ID()) for { // In case of both next.NextWaitChan() and peer.Quit() are variable at the same time @@ -236,10 +213,10 @@ func (memR *Reactor) broadcastTxRoutine(peer p2p.Peer) { // This happens because the CElement we were looking at got garbage // collected (removed). That is, .NextWait() returned nil. Go ahead and // start from the beginning. - if next == nil { + if next.IsEmpty() { select { case <-memR.mempool.TxsWaitChan(): // Wait until a tx is available - if next = memR.mempool.TxsFront(); next == nil { + if next = memR.mempool.TxsFront(); next.IsEmpty() { continue } case <-peer.Quit(): @@ -267,8 +244,7 @@ func (memR *Reactor) broadcastTxRoutine(peer p2p.Peer) { // node. See [RFC 103] for an analysis on this optimization. 
// // [RFC 103]: https://github.com/cometbft/cometbft/pull/735 - memTx := next.Value.(*mempoolTx) - if peerState.GetHeight() < memTx.Height()-1 { + if peerState.GetHeight() < next.Height()-1 { time.Sleep(PeerCatchupSleepIntervalMS * time.Millisecond) continue } @@ -276,10 +252,11 @@ func (memR *Reactor) broadcastTxRoutine(peer p2p.Peer) { // NOTE: Transaction batching was disabled due to // https://github.com/tendermint/tendermint/issues/5796 - if !memR.isSender(memTx.tx.Key(), peer.ID()) { + tx := next.Tx() + if !memR.isSender(tx.Key(), peer.ID()) { success := peer.Send(p2p.Envelope{ ChannelID: MempoolChannel, - Message: &protomem.Txs{Txs: [][]byte{memTx.tx}}, + Message: &protomem.Txs{Txs: [][]byte{tx}}, }) if !success { time.Sleep(PeerCatchupSleepIntervalMS * time.Millisecond) diff --git a/mempool/reactor_test.go b/mempool/reactor_test.go index dc8c7571621..4eba1fb8586 100644 --- a/mempool/reactor_test.go +++ b/mempool/reactor_test.go @@ -561,8 +561,7 @@ func makeAndConnectReactors(config *cfg.Config, n int) ([]*Reactor, []*p2p.Switc mempool, cleanup := newMempoolWithApp(cc) defer cleanup() - reactors[i] = NewReactor(config.Mempool, mempool, false) // so we dont start the consensus states - reactors[i].SetLogger(logger.With("validator", i)) + reactors[i] = NewReactor(config.Mempool, mempool, false, logger.With("validator", i)) // so we don't start the consensus states } switches := p2p.MakeConnectedSwitches(config.P2P, n, func(i int, s *p2p.Switch) *p2p.Switch { @@ -582,8 +581,7 @@ func makeAndConnectReactorsStar(config *cfg.Config, c, n int) ([]*Reactor, []*p2 mempool, cleanup := newMempoolWithApp(cc) defer cleanup() - reactors[i] = NewReactor(config.Mempool, mempool, false) // so we dont start the consensus states - reactors[i].SetLogger(logger.With("validator", i)) + reactors[i] = NewReactor(config.Mempool, mempool, false, logger.With("validator", i)) // so we dont start the consensus states } switches := p2p.MakeConnectedSwitches(config.P2P, n, func(i int, s *p2p.Switch) *p2p.Switch { diff --git a/mempool/sync_reactor.go b/mempool/sync_reactor.go new file mode 100644 index 00000000000..72c0444bf49 --- /dev/null +++ b/mempool/sync_reactor.go @@ -0,0 +1,47 @@ +package mempool + +import ( + "sync/atomic" + + cfg "github.com/cometbft/cometbft/config" + "github.com/cometbft/cometbft/p2p" +) + +// Base mempool reactor with a configuration. It must implement the WaitSyncP2PReactor interface to +// allow the node to transition from block sync or state sync to consensus mode. +type WaitSyncReactor struct { + p2p.BaseReactor + Config *cfg.MempoolConfig + + waitSync atomic.Bool + waitSyncCh chan struct{} // for signaling when to start receiving and sending txs +} + +func NewWaitSyncReactor(config *cfg.MempoolConfig, waitSync bool) *WaitSyncReactor { + baseR := &WaitSyncReactor{Config: config, waitSync: atomic.Bool{}} + if waitSync { + baseR.waitSync.Store(true) + baseR.waitSyncCh = make(chan struct{}) + } + return baseR +} + +func (memR *WaitSyncReactor) EnableInOutTxs() { + memR.Logger.Info("enabling inbound and outbound transactions") + if !memR.waitSync.CompareAndSwap(true, false) { + return + } + + // Releases all the blocked broadcastTxRoutine instances. 
+ if memR.Config.Broadcast { + close(memR.waitSyncCh) + } +} + +func (memR *WaitSyncReactor) WaitSync() bool { + return memR.waitSync.Load() +} + +func (memR *WaitSyncReactor) WaitSyncChan() chan struct{} { + return memR.waitSyncCh +} diff --git a/node/node.go b/node/node.go index 076eb00d0b0..3117179bba7 100644 --- a/node/node.go +++ b/node/node.go @@ -31,6 +31,7 @@ import ( "github.com/cometbft/cometbft/libs/log" "github.com/cometbft/cometbft/light" mempl "github.com/cometbft/cometbft/mempool" + "github.com/cometbft/cometbft/mempool/cat" "github.com/cometbft/cometbft/p2p" "github.com/cometbft/cometbft/p2p/pex" "github.com/cometbft/cometbft/proxy" @@ -67,7 +68,7 @@ type Node struct { blockStore *store.BlockStore // store the blockchain to disk pruner *sm.Pruner bcReactor p2p.Reactor // for block-syncing - mempoolReactor waitSyncP2PReactor // for gossipping transactions + mempoolReactor WaitSyncP2PReactor // for gossipping transactions mempool mempl.Mempool stateSync bool // whether the node should state sync on startup stateSyncReactor *statesync.Reactor // for hosting and restoring state sync snapshots @@ -86,7 +87,8 @@ type Node struct { pprofSrv *http.Server } -type waitSyncP2PReactor interface { +// A reactor that transitions from block sync or state sync to consensus mode. +type WaitSyncP2PReactor interface { p2p.Reactor // required by RPC service WaitSync() bool @@ -995,6 +997,10 @@ func makeNodeInfo( nodeInfo.Channels = append(nodeInfo.Channels, pex.PexChannel) } + if config.Mempool.Type == cfg.MempoolTypeCat { + nodeInfo.Channels = append(nodeInfo.Channels, cat.MempoolStateChannel) + } + lAddr := config.P2P.ExternalAddress if lAddr == "" { diff --git a/node/setup.go b/node/setup.go index da87ff2cc52..00bda132921 100644 --- a/node/setup.go +++ b/node/setup.go @@ -31,6 +31,7 @@ import ( "github.com/cometbft/cometbft/libs/log" "github.com/cometbft/cometbft/light" mempl "github.com/cometbft/cometbft/mempool" + "github.com/cometbft/cometbft/mempool/cat" "github.com/cometbft/cometbft/p2p" "github.com/cometbft/cometbft/p2p/pex" "github.com/cometbft/cometbft/privval" @@ -246,10 +247,11 @@ func createMempoolAndMempoolReactor( waitSync bool, memplMetrics *mempl.Metrics, logger log.Logger, -) (mempl.Mempool, waitSyncP2PReactor) { +) (mempl.Mempool, WaitSyncP2PReactor) { switch config.Mempool.Type { // allow empty string for backward compatibility case cfg.MempoolTypeFlood, "": + logger.Info("Using the default mempool with a flooding gossip protocol") logger = logger.With("module", "mempool") mp := mempl.NewCListMempool( config.Mempool, @@ -259,16 +261,32 @@ func createMempoolAndMempoolReactor( mempl.WithPreCheck(sm.TxPreCheck(state)), mempl.WithPostCheck(sm.TxPostCheck(state)), ) - mp.SetLogger(logger) reactor := mempl.NewReactor( config.Mempool, mp, waitSync, + logger, ) if config.Consensus.WaitForTxs() { mp.EnableTxsAvailable() } - reactor.SetLogger(logger) + + return mp, reactor + case cfg.MempoolTypeCat: + logger.Info("Using the mempool with the push-pull gossip protocol (CAT)") + logger = logger.With("module", "mempool") + mp := mempl.NewCListMempool( + config.Mempool, + proxyApp.Mempool(), + state.LastBlockHeight, + mempl.WithMetrics(memplMetrics), + mempl.WithPreCheck(sm.TxPreCheck(state)), + mempl.WithPostCheck(sm.TxPostCheck(state)), + ) + reactor := cat.NewReactor(config.Mempool, mp, waitSync, logger) + if config.Consensus.WaitForTxs() { + mp.EnableTxsAvailable() + } return mp, reactor case cfg.MempoolTypeNop: diff --git a/proto/cometbft/mempool/v1/types.proto 
b/proto/cometbft/mempool/v1/types.proto index 1ab4e74543c..aac204248fb 100644 --- a/proto/cometbft/mempool/v1/types.proto +++ b/proto/cometbft/mempool/v1/types.proto @@ -8,10 +8,22 @@ message Txs { repeated bytes txs = 1; } +// SeenTx contains a list of transaction keys seen by the sender. +message SeenTx { + bytes tx_key = 1; +} + +// WantTx contains a list of transaction keys wanted by the sender. +message WantTx { + bytes tx_key = 1; +} + // Message is an abstract mempool message. message Message { // Sum of all possible messages. oneof sum { - Txs txs = 1; + Txs txs = 1; + SeenTx seen_tx = 2; + WantTx want_tx = 3; } } diff --git a/proto/tendermint/mempool/types.proto b/proto/tendermint/mempool/types.proto index 7fa53ef79d8..b01d6ce36b9 100644 --- a/proto/tendermint/mempool/types.proto +++ b/proto/tendermint/mempool/types.proto @@ -5,8 +5,18 @@ message Txs { repeated bytes txs = 1; } +message SeenTx { + bytes tx_key = 1; +} + +message WantTx { + bytes tx_key = 1; +} + message Message { oneof sum { - Txs txs = 1; + Txs txs = 1; + SeenTx seen_tx = 2; + WantTx want_tx = 3; } } diff --git a/rpc/core/mempool.go b/rpc/core/mempool.go index 3c38678dfb0..46e1f81f9bd 100644 --- a/rpc/core/mempool.go +++ b/rpc/core/mempool.go @@ -24,10 +24,16 @@ func (env *Environment) BroadcastTxAsync(_ *rpctypes.Context, tx types.Tx) (*cty if env.MempoolReactor.WaitSync() { return nil, ErrEndpointClosedCatchingUp } - _, err := env.Mempool.CheckTx(tx) + reqRes, err := env.Mempool.CheckNewTx(tx) if err != nil { return nil, err } + reqRes.SetCallback(func(res *abci.Response) { + resp := reqRes.Response.GetCheckTx() + if resp.Code == abci.CodeTypeOK { + env.Mempool.InvokeNewTxReceivedOnReactor(tx.Key()) + } + }) return &ctypes.ResultBroadcastTx{Hash: tx.Hash()}, nil } @@ -40,15 +46,16 @@ func (env *Environment) BroadcastTxSync(ctx *rpctypes.Context, tx types.Tx) (*ct } resCh := make(chan *abci.CheckTxResponse, 1) - reqRes, err := env.Mempool.CheckTx(tx) + reqRes, err := env.Mempool.CheckNewTx(tx) if err != nil { return nil, err } reqRes.SetCallback(func(res *abci.Response) { - select { - case <-ctx.Context().Done(): - case resCh <- reqRes.Response.GetCheckTx(): + resp := reqRes.Response.GetCheckTx() + if resp.Code == abci.CodeTypeOK { + env.Mempool.InvokeNewTxReceivedOnReactor(tx.Key()) } + resCh <- resp }) select { case <-ctx.Context().Done(): @@ -97,16 +104,17 @@ func (env *Environment) BroadcastTxCommit(ctx *rpctypes.Context, tx types.Tx) (* // Broadcast tx and wait for CheckTx result checkTxResCh := make(chan *abci.CheckTxResponse, 1) - reqRes, err := env.Mempool.CheckTx(tx) + reqRes, err := env.Mempool.CheckNewTx(tx) if err != nil { env.Logger.Error("Error on broadcastTxCommit", "err", err) return nil, fmt.Errorf("error on broadcastTxCommit: %v", err) } reqRes.SetCallback(func(res *abci.Response) { - select { - case <-ctx.Context().Done(): - case checkTxResCh <- reqRes.Response.GetCheckTx(): + resp := reqRes.Response.GetCheckTx() + if resp.Code == abci.CodeTypeOK { + env.Mempool.InvokeNewTxReceivedOnReactor(tx.Key()) } + checkTxResCh <- resp }) select { case <-ctx.Context().Done(): diff --git a/test/e2e/Makefile b/test/e2e/Makefile index 74e6118ba04..4e44fb03a96 100644 --- a/test/e2e/Makefile +++ b/test/e2e/Makefile @@ -4,6 +4,8 @@ include ../../common.mk all: docker generator runner +fast: docker-fast generator runner + docker: @echo "Building E2E Docker image" @docker build \ @@ -45,4 +47,4 @@ lint: grammar-gen: go run github.com/goccmack/gogll/v3@latest -o pkg/grammar/grammar-auto 
pkg/grammar/abci_grammar.md -.PHONY: all node docker docker-debug docker-fast node-fast generator runner lint grammar-gen +.PHONY: all node docker docker-debug fast docker-fast node-fast generator runner lint grammar-gen diff --git a/test/e2e/pkg/manifest.go b/test/e2e/pkg/manifest.go index 13a2c004c91..8cdfcb2afbb 100644 --- a/test/e2e/pkg/manifest.go +++ b/test/e2e/pkg/manifest.go @@ -96,12 +96,21 @@ type Manifest struct { // Defaults to false (disabled). Prometheus bool `toml:"prometheus"` + // Set true to enable the peer-exchange reactor on all nodes. + PexReactor bool `toml:"pex"` + + // LogLevel sets the log level on all nodes. + LogLevel string `toml:"log_level"` + // Defines a minimum size for the vote extensions. VoteExtensionSize uint `toml:"vote_extension_size"` // Upper bound of sleep duration then gossipping votes and block parts PeerGossipIntraloopSleepDuration time.Duration `toml:"peer_gossip_intraloop_sleep_duration"` + // MempoolType determines the mempool type of all nodes. + MempoolType string `toml:"mempool_type"` + // Maximum number of peers to which the node gossips transactions ExperimentalMaxGossipConnectionsToPersistentPeers uint `toml:"experimental_max_gossip_connections_to_persistent_peers"` ExperimentalMaxGossipConnectionsToNonPersistentPeers uint `toml:"experimental_max_gossip_connections_to_non_persistent_peers"` diff --git a/test/e2e/pkg/testnet.go b/test/e2e/pkg/testnet.go index 5a511e4d0a0..f65addcaef3 100644 --- a/test/e2e/pkg/testnet.go +++ b/test/e2e/pkg/testnet.go @@ -19,6 +19,7 @@ import ( _ "embed" + "github.com/cometbft/cometbft/config" "github.com/cometbft/cometbft/crypto" "github.com/cometbft/cometbft/crypto/ed25519" "github.com/cometbft/cometbft/crypto/secp256k1" @@ -94,9 +95,12 @@ type Testnet struct { FinalizeBlockDelay time.Duration UpgradeVersion string Prometheus bool + PexReactor bool + LogLevel string VoteExtensionsEnableHeight int64 VoteExtensionSize uint PeerGossipIntraloopSleepDuration time.Duration + MempoolType string ExperimentalMaxGossipConnectionsToPersistentPeers uint ExperimentalMaxGossipConnectionsToNonPersistentPeers uint ABCITestsEnabled bool @@ -182,9 +186,12 @@ func NewTestnetFromManifest(manifest Manifest, file string, ifd InfrastructureDa FinalizeBlockDelay: manifest.FinalizeBlockDelay, UpgradeVersion: manifest.UpgradeVersion, Prometheus: manifest.Prometheus, + PexReactor: manifest.PexReactor, + LogLevel: manifest.LogLevel, VoteExtensionsEnableHeight: manifest.VoteExtensionsEnableHeight, VoteExtensionSize: manifest.VoteExtensionSize, PeerGossipIntraloopSleepDuration: manifest.PeerGossipIntraloopSleepDuration, + MempoolType: manifest.MempoolType, ExperimentalMaxGossipConnectionsToPersistentPeers: manifest.ExperimentalMaxGossipConnectionsToPersistentPeers, ExperimentalMaxGossipConnectionsToNonPersistentPeers: manifest.ExperimentalMaxGossipConnectionsToNonPersistentPeers, ABCITestsEnabled: manifest.ABCITestsEnabled, @@ -211,6 +218,9 @@ func NewTestnetFromManifest(manifest Manifest, file string, ifd InfrastructureDa if testnet.LoadTxSizeBytes == 0 { testnet.LoadTxSizeBytes = defaultTxSizeBytes } + if testnet.LogLevel == "" { + testnet.LogLevel = config.DefaultLogLevel + } for _, name := range sortNodeNames(manifest) { nodeManifest := manifest.Nodes[name] diff --git a/test/e2e/runner/setup.go b/test/e2e/runner/setup.go index cd78b7b4618..1d39bfaf2b9 100644 --- a/test/e2e/runner/setup.go +++ b/test/e2e/runner/setup.go @@ -183,6 +183,8 @@ func MakeConfig(node *e2e.Node) (*config.Config, error) { cfg.StateSync.DiscoveryTime = 5 * 
time.Second cfg.BlockSync.Version = node.BlockSyncVersion cfg.Consensus.PeerGossipIntraloopSleepDuration = node.Testnet.PeerGossipIntraloopSleepDuration + cfg.P2P.PexReactor = node.Testnet.PexReactor + cfg.LogLevel = node.Testnet.LogLevel cfg.Mempool.ExperimentalMaxGossipConnectionsToNonPersistentPeers = int(node.Testnet.ExperimentalMaxGossipConnectionsToNonPersistentPeers) cfg.Mempool.ExperimentalMaxGossipConnectionsToPersistentPeers = int(node.Testnet.ExperimentalMaxGossipConnectionsToPersistentPeers) @@ -269,6 +271,9 @@ func MakeConfig(node *e2e.Node) (*config.Config, error) { } cfg.P2P.PersistentPeers += peer.AddressP2P(true) } + if node.Testnet.MempoolType != "" { + cfg.Mempool.Type = node.Testnet.MempoolType + } if node.Testnet.DisablePexReactor { cfg.P2P.PexReactor = false } diff --git a/types/tx.go b/types/tx.go index fa6930a56e6..87cd439c62b 100644 --- a/types/tx.go +++ b/types/tx.go @@ -39,6 +39,19 @@ func (tx Tx) String() string { return fmt.Sprintf("Tx{%X}", []byte(tx)) } +func (key TxKey) String() string { + return fmt.Sprintf("TxKey{%X}", key[:]) +} + +func TxKeyFromBytes(bytes []byte) (TxKey, error) { + if len(bytes) != TxKeySize { + return TxKey{}, fmt.Errorf("incorrect tx key size. Expected %d bytes, got %d", TxKeySize, len(bytes)) + } + var key TxKey + copy(key[:], bytes) + return key, nil +} + // Txs is a slice of Tx. type Txs []Tx
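Usage note for the TxKey helpers added in types/tx.go: TxKeyFromBytes validates the length of its input and TxKey.String returns a hex representation. A minimal, self-contained sketch (not part of the diff) that round-trips a key:

package main

import (
	"fmt"

	"github.com/cometbft/cometbft/types"
)

func main() {
	tx := types.Tx("example transaction")
	key := tx.Key() // fixed-size [types.TxKeySize]byte key derived from the tx

	// TxKeyFromBytes fails when the slice is not exactly TxKeySize bytes long.
	parsed, err := types.TxKeyFromBytes(key[:])
	if err != nil {
		panic(err)
	}

	fmt.Println(parsed.String()) // prints the key in TxKey{<hex>} form
	fmt.Println(parsed == key)   // true: the round trip preserves the key
}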