Commit 74da5b95 authored by libangzhu's avatar libangzhu

增加测试用例

parent 518ee23c
...@@ -78,6 +78,14 @@ serverStart=true ...@@ -78,6 +78,14 @@ serverStart=true
innerSeedEnable=true innerSeedEnable=true
useGithub=true useGithub=true
innerBounds=300 innerBounds=300
#是否启用ssl/tls 通信,默认不开启
enableTLS=false
#如果需要CA配合认证,则需要配置caCert,caServer
caCert=""
certFile=""
keyFile=""
# ca服务端接口http://ip:port
caServer=""
[p2p.sub.dht] [p2p.sub.dht]
seeds=[] seeds=[]
......
...@@ -95,7 +95,7 @@ Retry: ...@@ -95,7 +95,7 @@ Retry:
if serialNum, ok := latestSerials.Load(ip); ok { if serialNum, ok := latestSerials.Load(ip); ok {
bn, _ := big.NewInt(1).SetString(serialNum.(string), 10) bn, _ := big.NewInt(1).SetString(serialNum.(string), 10)
if isRevoke(bn) { //证书被吊销 拒绝接口请求 if isRevoke(bn) { //证书被吊销 拒绝接口请求
return nil, fmt.Errorf("cert %v revoked", serialNum.(string)) return nil, fmt.Errorf("interceptor: authentication interceptor faild Certificate SerialNumber %v revoked", serialNum.(string))
} }
} }
if pServer.node.nodeInfo.blacklist.Has(ip) { if pServer.node.nodeInfo.blacklist.Has(ip) {
...@@ -126,7 +126,7 @@ Retry: ...@@ -126,7 +126,7 @@ Retry:
if serialNum, ok := latestSerials.Load(ip); ok { if serialNum, ok := latestSerials.Load(ip); ok {
bn, _ := big.NewInt(1).SetString(serialNum.(string), 10) bn, _ := big.NewInt(1).SetString(serialNum.(string), 10)
if isRevoke(bn) { //证书被吊销 拒绝接口请求 if isRevoke(bn) { //证书被吊销 拒绝接口请求
return fmt.Errorf("cert %v revoked", serialNum.(string)) return fmt.Errorf("interceptor: authentication Stream faild Certificate SerialNumber %v revoked", serialNum.(string))
} }
} }
......
...@@ -603,7 +603,7 @@ func (n *Node) monitorCfgSeeds() { ...@@ -603,7 +603,7 @@ func (n *Node) monitorCfgSeeds() {
} }
func (n *Node) monitorCerts() { func (n *Node) monitorCerts() {
if !n.nodeInfo.cfg.EnableTls { if !n.nodeInfo.cfg.EnableTls || n.nodeInfo.cfg.CaServer == "" {
return return
} }
ticker := time.NewTicker(CheckCfgCertInterVal) ticker := time.NewTicker(CheckCfgCertInterVal)
...@@ -619,7 +619,9 @@ func (n *Node) monitorCerts() { ...@@ -619,7 +619,9 @@ func (n *Node) monitorCerts() {
case <-ticker.C: case <-ticker.C:
//check serialNum //check serialNum
var resp []string var resp []string
var s Serial var s struct {
Serials []string `json:"serials,omitempty"`
}
s.Serials = getSerialNums() s.Serials = getSerialNums()
if len(s.Serials) == 0 { if len(s.Serials) == 0 {
continue continue
...@@ -647,13 +649,6 @@ func (n *Node) monitorCerts() { ...@@ -647,13 +649,6 @@ func (n *Node) monitorCerts() {
certinfo := updateCertSerial(sNum, true) certinfo := updateCertSerial(sNum, true)
delete(tempCerts, sNum.String()) delete(tempCerts, sNum.String())
if certinfo != nil { if certinfo != nil {
//断开节点连接
//if ip-->serialNum == sNum{
// close connect
//}else{
// }
//log.Info("monitorCerts","add blacklist",certinfo.ip)
//n.nodeInfo.blacklist.Add(certinfo.ip, 60)
for pname, peer := range n.nodeInfo.peerInfos.GetPeerInfos() { for pname, peer := range n.nodeInfo.peerInfos.GetPeerInfos() {
if peer.GetAddr() == certinfo.ip { if peer.GetAddr() == certinfo.ip {
v, ok := latestSerials.Load(certinfo.ip) v, ok := latestSerials.Load(certinfo.ip)
......
...@@ -11,7 +11,6 @@ import ( ...@@ -11,7 +11,6 @@ import (
"strconv" "strconv"
"time" "time"
pr "google.golang.org/grpc/peer"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
...@@ -155,7 +154,7 @@ func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCred ...@@ -155,7 +154,7 @@ func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCred
cliparm.Timeout = 10 * time.Second //ping后的获取ack消息超时时间 cliparm.Timeout = 10 * time.Second //ping后的获取ack消息超时时间
cliparm.PermitWithoutStream = true //启动keepalive 进行检查 cliparm.PermitWithoutStream = true //启动keepalive 进行检查
keepaliveOp := grpc.WithKeepaliveParams(cliparm) keepaliveOp := grpc.WithKeepaliveParams(cliparm)
log.Info("NetAddress", "Dial------------->", na.String()) log.Debug("DialTimeout", "Dial------------->", na.String())
maxMsgSize := pb.MaxBlockSize + 1024*1024 maxMsgSize := pb.MaxBlockSize + 1024*1024
//配置SSL连接 //配置SSL连接
var secOpt grpc.DialOption var secOpt grpc.DialOption
...@@ -168,15 +167,14 @@ func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCred ...@@ -168,15 +167,14 @@ func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCred
interceptor := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { interceptor := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
// 黑名单校验 // 黑名单校验
//checkAuth //checkAuth
log.Info("client interceptor") log.Debug("interceptor client", "remoteAddr", na.String())
ip, _, err := net.SplitHostPort(na.String()) ip, _, err := net.SplitHostPort(na.String())
if err != nil { if err != nil {
return err return err
} }
log.Info("interceptor client", "remoteAddr", na.String())
if bList != nil && bList.Has(ip) { if bList != nil && bList.Has(ip)|| bList!=nil &&bList.Has(na.String()) {
return fmt.Errorf("blacklist peer %v no authorized", ip) return fmt.Errorf("interceptor blacklist peer %v no authorized", na.String())
} }
return invoker(ctx, method, req, reply, cc, opts...) return invoker(ctx, method, req, reply, cc, opts...)
...@@ -184,33 +182,23 @@ func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCred ...@@ -184,33 +182,23 @@ func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCred
//流拦截器 //流拦截器
interceptorStream := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { interceptorStream := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
log.Info("client interceptorStream")
ip, _, err := net.SplitHostPort(na.String()) ip, _, err := net.SplitHostPort(na.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Info("interceptorStream client", "remoteAddr", na.String())
if bList.Has(ip) { if bList.Has(ip) {
return nil, fmt.Errorf("blacklist peer %v no authorized", ip) return nil, fmt.Errorf("blacklist peer %v no authorized", ip)
} }
return streamer(ctx, desc, cc, method, opts...) return streamer(ctx, desc, cc, method, opts...)
} }
//grpc.WithPerRPCCredentials
tcpAddr, err := net.ResolveTCPAddr("tcp", na.String())
if err != nil {
return nil, err
}
peer := &pr.Peer{
Addr: tcpAddr,
AuthInfo: nil,
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel() defer cancel()
ctxV := pr.NewContext(ctx, peer) conn, err := grpc.DialContext(ctx, na.String(),
conn, err := grpc.DialContext(ctxV, na.String(),
grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")), grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMsgSize)), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMsgSize)),
grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(maxMsgSize)), grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(maxMsgSize)),
...@@ -225,7 +213,7 @@ func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCred ...@@ -225,7 +213,7 @@ func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCred
//判断是否对方是否支持压缩 //判断是否对方是否支持压缩
cli := pb.NewP2PgserviceClient(conn) cli := pb.NewP2PgserviceClient(conn)
_, err = cli.GetHeaders(context.Background(), &pb.P2PGetHeaders{StartHeight: 0, EndHeight: 0, Version: version}, grpc.FailFast(true)) _, err = cli.GetHeaders(context.Background(), &pb.P2PGetHeaders{StartHeight: 0, EndHeight: 0, Version: version}, grpc.WaitForReady(true))
if err != nil && !isCompressSupport(err) { if err != nil && !isCompressSupport(err) {
//compress not support //compress not support
log.Error("compress not supprot , rollback to uncompress version", "addr", na.String()) log.Error("compress not supprot , rollback to uncompress version", "addr", na.String())
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"google.golang.org/grpc/credentials"
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
...@@ -86,7 +87,6 @@ type Node struct { ...@@ -86,7 +87,6 @@ type Node struct {
pubsub *pubsub.PubSub pubsub *pubsub.PubSub
chainCfg *types.Chain33Config chainCfg *types.Chain33Config
p2pMgr *p2p.Manager p2pMgr *p2p.Manager
//tls *Tls
} }
// SetQueueClient return client for nodeinfo // SetQueueClient return client for nodeinfo
...@@ -129,14 +129,25 @@ func NewNode(mgr *p2p.Manager, mcfg *subConfig) (*Node, error) { ...@@ -129,14 +129,25 @@ func NewNode(mgr *p2p.Manager, mcfg *subConfig) (*Node, error) {
node.chainCfg = cfg node.chainCfg = cfg
if mcfg.EnableTls { //读取证书,初始化tls客户端 if mcfg.EnableTls { //读取证书,初始化tls客户端
var err error var err error
if mcfg.CaCert == "" {
//不需要CA
node.nodeInfo.cliCreds, err = credentials.NewClientTLSFromFile(mcfg.CertFile, "")
if err != nil {
panic(err)
}
node.nodeInfo.servCreds, err = credentials.NewServerTLSFromFile(mcfg.CertFile, mcfg.KeyFile)
if err != nil {
panic(err)
}
} else {
//CA
cert, err := tls.LoadX509KeyPair(mcfg.CertFile, mcfg.KeyFile) cert, err := tls.LoadX509KeyPair(mcfg.CertFile, mcfg.KeyFile)
if err != nil { if err != nil {
panic(err) panic(err)
} }
certPool := x509.NewCertPool() certPool := x509.NewCertPool()
//添加CA校验 //添加CA校验
//把CA证书读进去,尝试动态更新CA中的吊销列表 //把CA证书读进去,动态更新CA中的吊销列表
ca, err := ioutil.ReadFile(mcfg.CaCert) ca, err := ioutil.ReadFile(mcfg.CaCert)
if err != nil { if err != nil {
panic(err) panic(err)
...@@ -160,6 +171,7 @@ func NewNode(mgr *p2p.Manager, mcfg *subConfig) (*Node, error) { ...@@ -160,6 +171,7 @@ func NewNode(mgr *p2p.Manager, mcfg *subConfig) (*Node, error) {
RootCAs: certPool, RootCAs: certPool,
}) })
node.nodeInfo.caServer = mcfg.CaServer node.nodeInfo.caServer = mcfg.CaServer
}
} }
if mcfg.ServerStart { if mcfg.ServerStart {
......
...@@ -6,13 +6,13 @@ import ( ...@@ -6,13 +6,13 @@ import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/33cn/chain33/common/pubsub"
"google.golang.org/grpc/credentials"
"net" "net"
"sort" "sort"
"sync/atomic" "sync/atomic"
"time" "time"
"google.golang.org/grpc/credentials"
"github.com/33cn/chain33/p2p" "github.com/33cn/chain33/p2p"
"github.com/33cn/chain33/p2p/utils" "github.com/33cn/chain33/p2p/utils"
...@@ -372,7 +372,7 @@ func testGrpcStreamConns(t *testing.T, p2p *P2p) { ...@@ -372,7 +372,7 @@ func testGrpcStreamConns(t *testing.T, p2p *P2p) {
func testP2pComm(t *testing.T, p2p *P2p) { func testP2pComm(t *testing.T, p2p *P2p) {
addrs := P2pComm.AddrRouteble([]string{"localhost:53802"}, utils.CalcChannelVersion(testChannel, VERSION), nil,nil) addrs := P2pComm.AddrRouteble([]string{"localhost:53802"}, utils.CalcChannelVersion(testChannel, VERSION), nil, nil)
t.Log(addrs) t.Log(addrs)
i32 := P2pComm.BytesToInt32([]byte{0xff}) i32 := P2pComm.BytesToInt32([]byte{0xff})
t.Log(i32) t.Log(i32)
...@@ -625,3 +625,103 @@ RObdAoGBALP9HK7KuX7xl0cKBzOiXqnAyoMUfxvO30CsMI3DS0SrPc1p95OHswdu ...@@ -625,3 +625,103 @@ RObdAoGBALP9HK7KuX7xl0cKBzOiXqnAyoMUfxvO30CsMI3DS0SrPc1p95OHswdu
assert.NotNil(t, conn) assert.NotNil(t, conn)
} }
func TestCaCreds(t *testing.T) {
ca := `-----BEGIN CERTIFICATE-----
MIIDKDCCAhCgAwIBAgIQMKlTasMav0IcCFxNKBlKlzANBgkqhkiG9w0BAQsFADAS
MRAwDgYDVQQKEwdBY21lIENvMB4XDTIxMTAxMTA3MzkwN1oXDTIyMTAxMTA3Mzkw
N1owEjEQMA4GA1UEChMHQWNtZSBDbzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC
AQoCggEBAOYK2OA6jsIWGK1faMZHdCMGKcc2SqErBcU/Sqis455B+9DCfZjesnut
5YgopQmvPKHF4ZROAJYtaLaodnEK7uMH2nYDU8Cy6+zXHG0c4FCnZxTiNlplYlrP
qSeDX/Ms2b1XmHAl8i289+4BbxWIj6JbMwPX7iQ68o4xo/D/FG+yfRs3xFEdwB6p
tC2TUNMBzaY/f1e43fC71AFd3xk5iUWRr2FPCqdQHpi5tHRYZ3SMxc630B/ISaDg
/DMCYzUdU7XfgehpeUrfVszMrIggwN3SM6bKGI7Zkt+mHMngAT5v0VdI3W8c6lI7
WFEsPq2n55XXDfzt9enbGQEIsv7mZC8CAwEAAaN6MHgwDgYDVR0PAQH/BAQDAgKk
MBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYE
FOfRwVXYMI6PvtWOxoLVI5OZCC4NMCEGA1UdEQQaMBiHBMCoAFiHBMCoAKSHBMCo
AHmHBMCoAKIwDQYJKoZIhvcNAQELBQADggEBAFGTcaltmP0pTPSZD/bvfnO2fUDp
OPG5ka0zeg6wNtUNQK43CACHSXtdlKtuaqZAJ1c8S/Ocsac9LJsXnI8NX75uCxf4
sdaEJN7mEN4lrqqrfihqbdeuDysbwUcoFjg7dzYIGZtMm2BR4kMaSqOHHWHoiUup
ylt2x864WHRvfHx53L8l2u3ZgnxHNZ+rk4VODGcpsnun1poHmfW+xJhkhc9U/lGw
GctxUtk6NUse9nZNxZG6ieSOD2+o5NSwUXliksPXzPkGQSx7VVXfG+4szBeXD+9x
mtQaeUpsIJdxsGcc0Zmu6v5XrBZ5xsZbCt8nMVA6rsGPYhczSXuBnVY6zu8=
-----END CERTIFICATE-----`
cert := `-----BEGIN CERTIFICATE-----
MIIBzTCCAXSgAwIBAgIRAKA1R7bK7YPXBjHgoYqi+J0wCgYIKoZIzj0EAwIwQzEL
MAkGA1UEBhMCQ04xCzAJBgNVBAgTAlpKMQswCQYDVQQHEwJIWjEaMBgGA1UEAxMR
Y2hhaW4zMy1jYS1zZXJ2ZXIwHhcNMjExMDIyMDgwMTUyWhcNMjIwMTMwMDgwMTUy
WjBDMQswCQYDVQQGEwJDTjELMAkGA1UECBMCWkoxCzAJBgNVBAcTAkhaMRowGAYD
VQQDExFjaGFpbjMzLWNhLXNlcnZlcjBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IA
BMJSLzYghkU4SpHvguL2pzwzg8GOcBG5n4QX10e7ScQFx1kUmcB0xZ/oyFMIdFBH
3BJ/0zwInlNAo0ekgUtRYlSjSTBHMA4GA1UdDwEB/wQEAwIHgDAMBgNVHRMBAf8E
AjAAMCcGA1UdEQQgMB6HBMCoAFiHBMCoAHmHBMCoAKKHBMCoAKSHBMCoADswCgYI
KoZIzj0EAwIDRwAwRAIgBulQxbARTa9q6nA2ypZ5mX20dTactlPmLamI2xvaTU4C
ICQov1WBMv+P/pEL/CR8yKaVqggLa0B4KzDMji5u0zXd
-----END CERTIFICATE-----`
key := `-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgBabS0GvOURbOoP+u
mErJlKF2YVZfEwb2rjObA1q/hxqhRANCAATCUi82IIZFOEqR74Li9qc8M4PBjnAR
uZ+EF9dHu0nEBcdZFJnAdMWf6MhTCHRQR9wSf9M8CJ5TQKNHpIFLUWJU
-----END PRIVATE KEY-----`
certificate, err := tls.X509KeyPair([]byte(cert), []byte(key))
assert.Nil(t, err)
cp := x509.NewCertPool()
var node Node
node.nodeInfo = &NodeInfo{}
certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM([]byte(ca)); !ok {
assert.True(t, ok)
}
servCreds := newTLS(&tls.Config{
Certificates: []tls.Certificate{certificate},
ClientAuth: tls.RequireAndVerifyClientCert, //校验客户端证书,用ca.pem校验
ClientCAs: certPool,
})
cliCreds := newTLS(&tls.Config{
Certificates: []tls.Certificate{certificate},
ServerName: "",
RootCAs: certPool,
})
node.listenPort = 13332
node.nodeInfo.servCreds = servCreds
node.pubsub= pubsub.NewPubSub(10200)
l:= newListener("tcp", &node)
assert.NotNil(t, l)
go l.Start()
defer l.Close()
netAddr, err := NewNetAddressString(fmt.Sprintf("127.0.0.1:%v",node.listenPort))
assert.Nil(t, err)
conn, err := grpc.Dial(netAddr.String(), grpc.WithTransportCredentials(cliCreds))
assert.Nil(t, err)
assert.NotNil(t, conn)
conn.Close()
conn, err = grpc.Dial(netAddr.String())
assert.NotNil(t, err)
t.Log("without creds", err)
assert.Nil(t, conn)
conn, err = grpc.Dial(netAddr.String(), grpc.WithInsecure())
assert.Nil(t, err)
assert.NotNil(t, conn)
gcon, err := netAddr.DialTimeout(0, cliCreds, nil)
assert.NotNil(t, err)
t.Log(err.Error())
cp = x509.NewCertPool()
if !cp.AppendCertsFromPEM([]byte(cert)) {
return
}
cliCreds = credentials.NewClientTLSFromCert(cp, "")
gcon, err = netAddr.DialTimeout(0, cliCreds, nil)
assert.NotNil(t, err)
assert.Nil(t, gcon)
}
...@@ -110,7 +110,7 @@ func (m *Cli) GetMemPool(msg *queue.Message, taskindex int64) { ...@@ -110,7 +110,7 @@ func (m *Cli) GetMemPool(msg *queue.Message, taskindex int64) {
for _, peer := range peers { for _, peer := range peers {
//获取远程 peer invs //获取远程 peer invs
resp, err := peer.mconn.gcli.GetMemPool(context.Background(), resp, err := peer.mconn.gcli.GetMemPool(context.Background(),
&pb.P2PGetMempool{Version: m.network.node.nodeInfo.channelVersion}, grpc.FailFast(true)) &pb.P2PGetMempool{Version: m.network.node.nodeInfo.channelVersion}, grpc.WaitForReady(true))
P2pComm.CollectPeerStat(err, peer) P2pComm.CollectPeerStat(err, peer)
if err != nil { if err != nil {
if err == pb.ErrVersion { if err == pb.ErrVersion {
...@@ -142,7 +142,7 @@ func (m *Cli) GetMemPool(msg *queue.Message, taskindex int64) { ...@@ -142,7 +142,7 @@ func (m *Cli) GetMemPool(msg *queue.Message, taskindex int64) {
} }
//获取真正的交易Tx call GetData //获取真正的交易Tx call GetData
datacli, dataerr := peer.mconn.gcli.GetData(context.Background(), datacli, dataerr := peer.mconn.gcli.GetData(context.Background(),
&pb.P2PGetData{Invs: ableInv, Version: m.network.node.nodeInfo.channelVersion}, grpc.FailFast(true)) &pb.P2PGetData{Invs: ableInv, Version: m.network.node.nodeInfo.channelVersion}, grpc.WaitForReady(true))
P2pComm.CollectPeerStat(dataerr, peer) P2pComm.CollectPeerStat(dataerr, peer)
if dataerr != nil { if dataerr != nil {
continue continue
...@@ -174,7 +174,7 @@ func (m *Cli) GetMemPool(msg *queue.Message, taskindex int64) { ...@@ -174,7 +174,7 @@ func (m *Cli) GetMemPool(msg *queue.Message, taskindex int64) {
func (m *Cli) GetAddr(peer *Peer) ([]string, error) { func (m *Cli) GetAddr(peer *Peer) ([]string, error) {
resp, err := peer.mconn.gcli.GetAddr(context.Background(), &pb.P2PGetAddr{Nonce: int64(rand.Int31n(102040))}, resp, err := peer.mconn.gcli.GetAddr(context.Background(), &pb.P2PGetAddr{Nonce: int64(rand.Int31n(102040))},
grpc.FailFast(true)) grpc.WaitForReady(true))
P2pComm.CollectPeerStat(err, peer) P2pComm.CollectPeerStat(err, peer)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -192,7 +192,7 @@ func (m *Cli) GetInPeersNum(peer *Peer) (int, error) { ...@@ -192,7 +192,7 @@ func (m *Cli) GetInPeersNum(peer *Peer) (int, error) {
} }
resp, err := peer.mconn.gcli.CollectInPeers(context.Background(), ping, resp, err := peer.mconn.gcli.CollectInPeers(context.Background(), ping,
grpc.FailFast(true)) grpc.WaitForReady(true))
P2pComm.CollectPeerStat(err, peer) P2pComm.CollectPeerStat(err, peer)
if err != nil { if err != nil {
...@@ -210,7 +210,7 @@ func (m *Cli) GetAddrList(peer *Peer) (map[string]*pb.P2PPeerInfo, error) { ...@@ -210,7 +210,7 @@ func (m *Cli) GetAddrList(peer *Peer) (map[string]*pb.P2PPeerInfo, error) {
return addrlist, fmt.Errorf("pointer is nil") return addrlist, fmt.Errorf("pointer is nil")
} }
resp, err := peer.mconn.gcli.GetAddrList(context.Background(), &pb.P2PGetAddr{Nonce: int64(rand.Int31n(102040))}, resp, err := peer.mconn.gcli.GetAddrList(context.Background(), &pb.P2PGetAddr{Nonce: int64(rand.Int31n(102040))},
grpc.FailFast(true)) grpc.WaitForReady(true))
P2pComm.CollectPeerStat(err, peer) P2pComm.CollectPeerStat(err, peer)
if err != nil { if err != nil {
...@@ -272,7 +272,7 @@ func (m *Cli) SendVersion(peer *Peer, nodeinfo *NodeInfo) (string, error) { ...@@ -272,7 +272,7 @@ func (m *Cli) SendVersion(peer *Peer, nodeinfo *NodeInfo) (string, error) {
resp, err := peer.mconn.gcli.Version2(context.Background(), &pb.P2PVersion{Version: nodeinfo.channelVersion, Service: int64(nodeinfo.ServiceTy()), Timestamp: pb.Now().Unix(), resp, err := peer.mconn.gcli.Version2(context.Background(), &pb.P2PVersion{Version: nodeinfo.channelVersion, Service: int64(nodeinfo.ServiceTy()), Timestamp: pb.Now().Unix(),
AddrRecv: peer.Addr(), AddrFrom: addrfrom, Nonce: int64(rand.Int31n(102040)), AddrRecv: peer.Addr(), AddrFrom: addrfrom, Nonce: int64(rand.Int31n(102040)),
UserAgent: hex.EncodeToString(in.Sign.GetPubkey()), StartHeight: blockheight}, grpc.FailFast(true)) UserAgent: hex.EncodeToString(in.Sign.GetPubkey()), StartHeight: blockheight}, grpc.WaitForReady(true))
log.Debug("SendVersion", "resp", resp, "from", addrfrom, "to", peer.Addr()) log.Debug("SendVersion", "resp", resp, "from", addrfrom, "to", peer.Addr())
if err != nil { if err != nil {
log.Error("SendVersion", "Verson", err.Error(), "peer", peer.Addr()) log.Error("SendVersion", "Verson", err.Error(), "peer", peer.Addr())
...@@ -295,7 +295,7 @@ func (m *Cli) SendVersion(peer *Peer, nodeinfo *NodeInfo) (string, error) { ...@@ -295,7 +295,7 @@ func (m *Cli) SendVersion(peer *Peer, nodeinfo *NodeInfo) (string, error) {
log.Debug("sendVersion", "expect ip", ip, "pre externalip", nodeinfo.GetExternalAddr().IP.String()) log.Debug("sendVersion", "expect ip", ip, "pre externalip", nodeinfo.GetExternalAddr().IP.String())
if peer.IsPersistent() { if peer.IsPersistent() {
//永久加入黑名单 //永久加入黑名单
nodeinfo.blacklist.Add(ip, 0) nodeinfo.blacklist.Add(resp.GetAddrRecv(), 0)//把自己的IP:PORT 加入黑名单,防止连接到自己
} }
} }
} }
...@@ -317,7 +317,7 @@ func (m *Cli) SendPing(peer *Peer, nodeinfo *NodeInfo) error { ...@@ -317,7 +317,7 @@ func (m *Cli) SendPing(peer *Peer, nodeinfo *NodeInfo) error {
return err return err
} }
r, err := peer.mconn.gcli.Ping(context.Background(), ping, grpc.FailFast(true)) r, err := peer.mconn.gcli.Ping(context.Background(), ping, grpc.WaitForReady(true))
P2pComm.CollectPeerStat(err, peer) P2pComm.CollectPeerStat(err, peer)
if err != nil { if err != nil {
return err return err
...@@ -395,7 +395,7 @@ func (m *Cli) GetHeaders(msg *queue.Message, taskindex int64) { ...@@ -395,7 +395,7 @@ func (m *Cli) GetHeaders(msg *queue.Message, taskindex int64) {
if peer, ok := peers[pid[0]]; ok && peer != nil { if peer, ok := peers[pid[0]]; ok && peer != nil {
var err error var err error
headers, err := peer.mconn.gcli.GetHeaders(context.Background(), &pb.P2PGetHeaders{StartHeight: req.GetStart(), EndHeight: req.GetEnd(), headers, err := peer.mconn.gcli.GetHeaders(context.Background(), &pb.P2PGetHeaders{StartHeight: req.GetStart(), EndHeight: req.GetEnd(),
Version: m.network.node.nodeInfo.channelVersion}, grpc.FailFast(true)) Version: m.network.node.nodeInfo.channelVersion}, grpc.WaitForReady(true))
P2pComm.CollectPeerStat(err, peer) P2pComm.CollectPeerStat(err, peer)
if err != nil { if err != nil {
log.Error("GetBlocks", "Err", err.Error()) log.Error("GetBlocks", "Err", err.Error())
...@@ -587,7 +587,7 @@ func (m *Cli) CheckSelf(addr string, nodeinfo *NodeInfo) bool { ...@@ -587,7 +587,7 @@ func (m *Cli) CheckSelf(addr string, nodeinfo *NodeInfo) bool {
cli := pb.NewP2PgserviceClient(conn) cli := pb.NewP2PgserviceClient(conn)
resp, err := cli.GetPeerInfo(context.Background(), resp, err := cli.GetPeerInfo(context.Background(),
&pb.P2PGetPeerInfo{Version: nodeinfo.channelVersion}, grpc.FailFast(true)) &pb.P2PGetPeerInfo{Version: nodeinfo.channelVersion}, grpc.WaitForReady(true))
if err != nil { if err != nil {
return false return false
} }
......
...@@ -16,22 +16,24 @@ import ( ...@@ -16,22 +16,24 @@ import (
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
) )
var serials = make(map[string]*certInfo)
var latestSerials sync.Map
var revokeLock sync.Mutex
//Tls defines the specific interface for all the live gRPC wire
// protocols and supported transport security protocols (e.g., TLS, SSL).
type Tls struct { type Tls struct {
config *tls.Config config *tls.Config
} }
type certInfo struct { type certInfo struct {
revoke bool revoke bool
ip string ip string
serial string serial string
} }
type Serial struct { var(
Serials []string `json:"serials,omitempty"` serials = make(map[string]*certInfo)
} revokeLock sync.Mutex
latestSerials sync.Map
)
//serialNum -->ip //serialNum -->ip
func addCertSerial(serial *big.Int, ip string) { func addCertSerial(serial *big.Int, ip string) {
revokeLock.Lock() revokeLock.Lock()
...@@ -68,6 +70,7 @@ func removeCertSerial(serial *big.Int) { ...@@ -68,6 +70,7 @@ func removeCertSerial(serial *big.Int) {
defer revokeLock.Unlock() defer revokeLock.Unlock()
delete(serials, serial.String()) delete(serials, serial.String())
} }
func getSerialNums() []string { func getSerialNums() []string {
revokeLock.Lock() revokeLock.Lock()
defer revokeLock.Unlock() defer revokeLock.Unlock()
...@@ -142,12 +145,12 @@ func (c *Tls) ClientHandshake(ctx context.Context, authority string, rawConn net ...@@ -142,12 +145,12 @@ func (c *Tls) ClientHandshake(ctx context.Context, authority string, rawConn net
certNum := len(peerCert) certNum := len(peerCert)
if certNum > 0 { if certNum > 0 {
peerSerialNum := peerCert[0].SerialNumber peerSerialNum := peerCert[0].SerialNumber
log.Debug("ClientHandshake", "peerSerialNum", peerSerialNum, "certificate Num", certNum, "remoteAddr", rawConn.RemoteAddr(), "tlsInfo", tlsInfo) log.Debug("ClientHandshake", "Certificate SerialNumber", peerSerialNum, "Certificate Number", certNum, "RemoteAddr", rawConn.RemoteAddr(), "tlsInfo", tlsInfo)
addrSplites := strings.Split(rawConn.RemoteAddr().String(), ":") addrSplites := strings.Split(rawConn.RemoteAddr().String(), ":")
//检查证书是否被吊销 //检查证书是否被吊销
if isRevoke(peerSerialNum) { if isRevoke(peerSerialNum) {
conn.Close() conn.Close()
return nil, nil, errors.New(fmt.Sprintf("tls ClientHandshake %v revoked", peerSerialNum.String())) return nil, nil, errors.New(fmt.Sprintf("transport: authentication handshake failed: ClientHandshake Certificate SerialNumber %v revoked", peerSerialNum.String()))
} }
if len(addrSplites) > 0 { //服务端证书的序列号,已经其IP地址 if len(addrSplites) > 0 { //服务端证书的序列号,已经其IP地址
...@@ -181,12 +184,11 @@ func (c *Tls) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, ...@@ -181,12 +184,11 @@ func (c *Tls) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo,
certNum := len(peerCert) certNum := len(peerCert)
if certNum != 0 { if certNum != 0 {
peerSerialNum := peerCert[0].SerialNumber peerSerialNum := peerCert[0].SerialNumber
//log.Info("ServerHandshake","certinfo",string(tlsInfo.State.PeerCertificates[0].Raw)) log.Debug("ServerHandshake", "peerSerialNum", peerSerialNum, "Certificate Number", certNum, "RemoteAddr", rawConn.RemoteAddr(), "tlsinfo", tlsInfo, "remoteAddr", conn.RemoteAddr())
log.Debug("ServerHandshake", "peerSerialNum", peerSerialNum, "certificate Num", certNum, "remoteAddr", rawConn.RemoteAddr(), "tlsinfo", tlsInfo, "remoteAddr", conn.RemoteAddr())
if isRevoke(peerSerialNum) { if isRevoke(peerSerialNum) {
rawConn.Close() rawConn.Close()
return nil, nil, errors.New(fmt.Sprintf("tls ServerHandshake %s revoked", peerSerialNum.String())) return nil, nil, errors.New(fmt.Sprintf("transport: authentication handshake failed: ServerHandshake %s revoked", peerSerialNum.String()))
} }
addrSplites := strings.Split(rawConn.RemoteAddr().String(), ":") addrSplites := strings.Split(rawConn.RemoteAddr().String(), ":")
if len(addrSplites) > 0 { if len(addrSplites) > 0 {
...@@ -215,18 +217,6 @@ func newTLS(c *tls.Config) credentials.TransportCredentials { ...@@ -215,18 +217,6 @@ func newTLS(c *tls.Config) credentials.TransportCredentials {
return tc return tc
} }
//func upgradeTls(c *tls.Config,tc *Tls)credentials.TransportCredentials{
//
// tc.config=CloneTLSConfig(c)
// if tc.server{
// tc.serConf=tc.config
// }else{
// tc.cliConf=tc.config
// }
// tc.config.NextProtos = AppendH2ToNextProtos(tc.config.NextProtos)
// return tc
//}
func (c *Tls) Clone() credentials.TransportCredentials { func (c *Tls) Clone() credentials.TransportCredentials {
return newTLS(c.config) return newTLS(c.config)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment