Unverified Commit 958b08d8 authored by 33cn's avatar 33cn Committed by GitHub

Merge pull request #1074 from libangzhu/gossip-tls

增加ca认证下的tls通信
parents 02423b67 12701bfb
......@@ -78,6 +78,14 @@ serverStart=true
innerSeedEnable=true
useGithub=true
innerBounds=300
#是否启用ssl/tls 通信,默认不开启
enableTLS=false
#如果需要CA配合认证,则需要配置caCert,caServer
caCert=""
certFile=""
keyFile=""
# ca服务端接口http://ip:port
caServer=""
[p2p.sub.dht]
seeds=[]
......
......@@ -58,7 +58,7 @@ func (Comm) ParaseNetAddr(addr string) (string, int64, error) {
}
// AddrRouteble address router ,return enbale address
func (Comm) AddrRouteble(addrs []string, version int32, creds credentials.TransportCredentials) []string {
func (Comm) AddrRouteble(addrs []string, version int32, creds credentials.TransportCredentials, blist *BlackList) []string {
var enableAddrs []string
for _, addr := range addrs {
......@@ -67,7 +67,7 @@ func (Comm) AddrRouteble(addrs []string, version int32, creds credentials.Transp
log.Error("AddrRouteble", "NewNetAddressString", err.Error())
continue
}
conn, err := netaddr.DialTimeout(version, creds)
conn, err := netaddr.DialTimeout(version, creds, blist)
if err != nil {
//log.Error("AddrRouteble", "DialTimeout", err.Error())
continue
......@@ -110,15 +110,18 @@ func (c Comm) GetLocalAddr() string {
func (c Comm) dialPeerWithAddress(addr *NetAddress, persistent bool, node *Node) (*Peer, error) {
log.Debug("dialPeerWithAddress")
conn, err := addr.DialTimeout(node.nodeInfo.channelVersion, node.nodeInfo.cliCreds)
conn, err := addr.DialTimeout(node.nodeInfo.channelVersion, node.nodeInfo.cliCreds, node.nodeInfo.blacklist)
if err != nil {
log.Error("dialPeerWithAddress", "DialTimeoutErr", err.Error())
return nil, err
}
peer, err := c.newPeerFromConn(conn, addr, node)
if err != nil {
log.Error("dialPeerWithAddress", "newPeerFromConn", err)
err = conn.Close()
return nil, err
}
peer.SetAddr(addr)
......@@ -151,7 +154,6 @@ func (c Comm) dialPeerWithAddress(addr *NetAddress, persistent bool, node *Node)
peer.Close()
return nil, errors.New(fmt.Sprintf("duplicate connect %v", resp.UserAgent))
}
node.peerStore.Store(addr.String(), resp.UserAgent)
peer.SetPeerName(resp.UserAgent)
return peer, nil
......
......@@ -25,6 +25,7 @@ var (
CheckActivePeersInterVal = 5 * time.Second
CheckBlackListInterVal = 30 * time.Second
CheckCfgSeedsInterVal = 1 * time.Minute
CheckCfgCertInterVal = 30 * time.Second
)
const (
......
......@@ -91,6 +91,11 @@ Retry:
if err != nil {
return nil, err
}
if serialNum, ok := latestSerials.Load(ip); ok {
if isRevoke(serialNum.(string)) { //证书被吊销 拒绝接口请求
return nil, fmt.Errorf("interceptor: authentication interceptor faild Certificate SerialNumber %v revoked", serialNum.(string))
}
}
if pServer.node.nodeInfo.blacklist.Has(ip) {
return nil, fmt.Errorf("blacklist %v no authorized", ip)
}
......@@ -116,6 +121,12 @@ Retry:
if err != nil {
return err
}
if serialNum, ok := latestSerials.Load(ip); ok {
if isRevoke(serialNum.(string)) { //证书被吊销 拒绝接口请求
return fmt.Errorf("interceptor: authentication Stream faild Certificate SerialNumber %v revoked", serialNum.(string))
}
}
if pServer.node.nodeInfo.blacklist.Has(ip) {
return fmt.Errorf("blacklist %v no authorized", ip)
}
......@@ -146,12 +157,12 @@ Retry:
opts = append(opts, msgRecvOp, msgSendOp, grpc.KeepaliveEnforcementPolicy(kaep), keepOp, maxStreams, StatsOp)
if node.nodeInfo.servCreds != nil {
opts = append(opts, grpc.Creds(node.nodeInfo.servCreds))
}
dl.server = grpc.NewServer(opts...)
dl.p2pserver = pServer
pb.RegisterP2PgserviceServer(dl.server, pServer)
return dl
}
......
......@@ -11,8 +11,9 @@ import (
"strings"
"time"
"github.com/33cn/chain33/p2p/utils"
"github.com/33cn/chain33/rpc/jsonclient"
"github.com/33cn/chain33/p2p/utils"
"github.com/33cn/chain33/types"
)
......@@ -599,3 +600,63 @@ func (n *Node) monitorCfgSeeds() {
}
}
func (n *Node) monitorCerts() {
if !n.nodeInfo.cfg.EnableTls || n.nodeInfo.cfg.CaServer == "" {
return
}
ticker := time.NewTicker(CheckCfgCertInterVal)
defer ticker.Stop()
jcli, err := jsonclient.New("chain33-ca-server", n.nodeInfo.caServer, false)
if err != nil {
log.Error("monitorCerts", "rpc call err", err)
return
}
for {
select {
case <-ticker.C:
//check serialNum
var resp []string
var s struct {
Serials []string `json:"serials,omitempty"`
}
s.Serials = getSerialNums()
if len(s.Serials) == 0 {
continue
}
log.Debug("check cert serialNum++++++", "certNum.", len(s.Serials))
err = jcli.Call("Validate", s, &resp)
if err != nil {
log.Error("monitorCerts", "rpc call err", err)
continue
}
log.Debug("monitorCerts", "resp", resp)
tempCerts := getSerials()
for _, serialNum := range resp {
//设置证书序列号状态
certinfo := updateCertSerial(serialNum, true)
delete(tempCerts, serialNum)
for pname, peer := range n.nodeInfo.peerInfos.GetPeerInfos() {
if peer.GetAddr() == certinfo.ip {
v, ok := latestSerials.Load(certinfo.ip)
if ok && v.(string) == serialNum {
n.remove(pname) //断开已经连接的节点
}
}
}
}
log.Debug("monitorCert", "tempCerts", tempCerts)
//处理解除吊销的节点
for serialNum, info := range tempCerts {
if info.revoke {
// 被撤销的证书恢复正常
updateCertSerial(serialNum, !info.revoke)
}
}
}
}
}
......@@ -28,7 +28,7 @@ func TestNetAddress(t *testing.T) {
}
func TestAddrRouteble(t *testing.T) {
resp := P2pComm.AddrRouteble([]string{"114.55.101.159:13802"}, utils.CalcChannelVersion(119, VERSION), nil)
resp := P2pComm.AddrRouteble([]string{"114.55.101.159:13802"}, utils.CalcChannelVersion(119, VERSION), nil, nil)
t.Log(resp)
}
......
......@@ -11,6 +11,8 @@ import (
"strconv"
"time"
"google.golang.org/grpc/status"
"google.golang.org/grpc/credentials"
pb "github.com/33cn/chain33/types"
......@@ -136,14 +138,14 @@ func (na *NetAddress) Copy() *NetAddress {
// DialTimeout calls net.DialTimeout on the address.
func isCompressSupport(err error) bool {
var errstr = `grpc: Decompressor is not installed for grpc-encoding "gzip"`
if grpc.Code(err) == codes.Unimplemented && grpc.ErrorDesc(err) == errstr {
if status.Code(err) == codes.Unimplemented && status.Convert(err).Message() == errstr {
return false
}
return true
}
// DialTimeout dial timeout
func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCredentials) (*grpc.ClientConn, error) {
func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCredentials, bList *BlackList) (*grpc.ClientConn, error) {
ch := make(chan grpc.ServiceConfig, 1)
ch <- P2pComm.GrpcConfig()
......@@ -152,8 +154,7 @@ func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCred
cliparm.Timeout = 10 * time.Second //ping后的获取ack消息超时时间
cliparm.PermitWithoutStream = true //启动keepalive 进行检查
keepaliveOp := grpc.WithKeepaliveParams(cliparm)
timeoutOp := grpc.WithTimeout(time.Second * 3)
log.Debug("NetAddress", "Dial", na.String())
log.Debug("DialTimeout", "Dial------------->", na.String())
maxMsgSize := pb.MaxBlockSize + 1024*1024
//配置SSL连接
var secOpt grpc.DialOption
......@@ -162,21 +163,56 @@ func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCred
} else {
secOpt = grpc.WithTransportCredentials(creds)
}
//grpc.WithPerRPCCredentials
conn, err := grpc.Dial(na.String(),
//接口拦截器
interceptor := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
// 黑名单校验
//checkAuth
log.Debug("interceptor client", "remoteAddr", na.String())
ip, _, err := net.SplitHostPort(na.String())
if err != nil {
return err
}
if bList != nil && bList.Has(ip) || bList != nil && bList.Has(na.String()) {
return fmt.Errorf("interceptor blacklist peer %v no authorized", na.String())
}
return invoker(ctx, method, req, reply, cc, opts...)
}
//流拦截器
interceptorStream := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
ip, _, err := net.SplitHostPort(na.String())
if err != nil {
return nil, err
}
if bList.Has(ip) {
return nil, fmt.Errorf("blacklist peer %v no authorized", ip)
}
return streamer(ctx, desc, cc, method, opts...)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
conn, err := grpc.DialContext(ctx, na.String(),
grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMsgSize)),
grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(maxMsgSize)),
grpc.WithServiceConfig(ch), keepaliveOp, timeoutOp, secOpt)
grpc.WithServiceConfig(ch), keepaliveOp, secOpt,
grpc.WithUnaryInterceptor(interceptor), grpc.WithStreamInterceptor(interceptorStream))
if err != nil {
log.Debug("grpc DialCon", "did not connect", err, "addr", na.String())
log.Error("grpc DialCon", "did not connect", err, "addr", na.String())
return nil, err
}
//p2p version check 通过版本协议,获取通信session
//判断是否对方是否支持压缩
cli := pb.NewP2PgserviceClient(conn)
_, err = cli.GetHeaders(context.Background(), &pb.P2PGetHeaders{StartHeight: 0, EndHeight: 0, Version: version}, grpc.FailFast(true))
_, err = cli.GetHeaders(context.Background(), &pb.P2PGetHeaders{StartHeight: 0, EndHeight: 0, Version: version}, grpc.WaitForReady(false))
if err != nil && !isCompressSupport(err) {
//compress not support
log.Error("compress not supprot , rollback to uncompress version", "addr", na.String())
......@@ -187,7 +223,7 @@ func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCred
ch2 := make(chan grpc.ServiceConfig, 1)
ch2 <- P2pComm.GrpcConfig()
log.Debug("NetAddress", "Dial with unCompressor", na.String())
conn, err = grpc.Dial(na.String(), secOpt, grpc.WithServiceConfig(ch2), keepaliveOp, timeoutOp)
conn, err = grpc.DialContext(ctx, na.String(), secOpt, grpc.WithServiceConfig(ch2), keepaliveOp)
}
......
......@@ -5,7 +5,10 @@
package gossip
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"math/rand"
"google.golang.org/grpc/credentials"
......@@ -102,6 +105,7 @@ func NewNode(mgr *p2p.Manager, mcfg *subConfig) (*Node, error) {
pubsub: pubsub.NewPubSub(10200),
p2pMgr: mgr,
}
//node.tls = &Tls{serials:make(map[*big.Int]*certInfo)}
node.listenPort = 13802
if mcfg.Port != 0 && mcfg.Port <= 65535 && mcfg.Port > 1024 {
node.listenPort = int(mcfg.Port)
......@@ -126,18 +130,55 @@ func NewNode(mgr *p2p.Manager, mcfg *subConfig) (*Node, error) {
node.chainCfg = cfg
if mcfg.EnableTls { //读取证书,初始化tls客户端
var err error
node.nodeInfo.cliCreds, err = credentials.NewClientTLSFromFile(cfg.GetModuleConfig().RPC.CertFile, "")
if err != nil {
panic(err)
}
node.nodeInfo.servCreds, err = credentials.NewServerTLSFromFile(cfg.GetModuleConfig().RPC.CertFile, cfg.GetModuleConfig().RPC.KeyFile)
if err != nil {
panic(err)
if mcfg.CaCert == "" {
//不需要CA
node.nodeInfo.cliCreds, err = credentials.NewClientTLSFromFile(mcfg.CertFile, "")
if err != nil {
panic(fmt.Sprintf("NewClientTLSFromFile panic:%v", err.Error()))
}
node.nodeInfo.servCreds, err = credentials.NewServerTLSFromFile(mcfg.CertFile, mcfg.KeyFile)
if err != nil {
panic(fmt.Sprintf("NewServerTLSFromFile panic:%v", err.Error()))
}
} else {
//CA
cert, err := tls.LoadX509KeyPair(mcfg.CertFile, mcfg.KeyFile)
if err != nil {
panic(fmt.Sprintf("LoadX509KeyPair panic:%v", err.Error()))
}
certPool := x509.NewCertPool()
//添加CA校验
//把CA证书读进去,动态更新CA中的吊销列表
ca, err := ioutil.ReadFile(mcfg.CaCert)
if err != nil {
panic(fmt.Sprintf("readFile ca panic:%v", err.Error()))
}
if ok := certPool.AppendCertsFromPEM(ca); !ok {
panic("certPool.AppendCertsFromPEM err")
}
node.nodeInfo.servCreds = newTLS(&tls.Config{
Certificates: []tls.Certificate{cert},
ClientAuth: tls.RequireAndVerifyClientCert, //校验客户端证书,用ca.pem校验
ClientCAs: certPool,
})
// 构建基于 TLS 的 TransportCredentials 选项
// 在 Client 请求 Server 端时,Client 端会使用根证书和 ServerName 去对 Server 端进行校验
node.nodeInfo.cliCreds = newTLS(&tls.Config{
Certificates: []tls.Certificate{cert},
ServerName: "",
RootCAs: certPool,
})
node.nodeInfo.caServer = mcfg.CaServer
}
}
if mcfg.ServerStart {
node.server = newListener(protocol, node)
}
return node, nil
}
......@@ -171,7 +212,7 @@ func (n *Node) doNat() {
}
testExaddr := fmt.Sprintf("%v:%v", n.nodeInfo.GetExternalAddr().IP.String(), n.listenPort)
log.Info("TestNetAddr", "testExaddr", testExaddr)
if len(P2pComm.AddrRouteble([]string{testExaddr}, n.nodeInfo.channelVersion, n.nodeInfo.cliCreds)) != 0 {
if len(P2pComm.AddrRouteble([]string{testExaddr}, n.nodeInfo.channelVersion, n.nodeInfo.cliCreds, n.nodeInfo.blacklist)) != 0 {
log.Info("node outside")
n.nodeInfo.SetNetSide(true)
if netexaddr, err := NewNetAddressString(testExaddr); err == nil {
......@@ -374,6 +415,7 @@ func (n *Node) monitor() {
go n.monitorFilter()
go n.monitorPeers()
go n.nodeReBalance()
go n.monitorCerts()
}
func (n *Node) needMore() bool {
......@@ -454,7 +496,7 @@ func (n *Node) natMapPort() {
time.Sleep(time.Second)
}
var err error
if len(P2pComm.AddrRouteble([]string{n.nodeInfo.GetExternalAddr().String()}, n.nodeInfo.channelVersion, n.nodeInfo.cliCreds)) != 0 { //判断能否连通要映射的端口
if len(P2pComm.AddrRouteble([]string{n.nodeInfo.GetExternalAddr().String()}, n.nodeInfo.channelVersion, n.nodeInfo.cliCreds, n.nodeInfo.blacklist)) != 0 { //判断能否连通要映射的端口
log.Info("natMapPort", "addr", "routeble")
p2pcli := NewNormalP2PCli() //检查要映射的IP地址是否已经被映射成功
ok := p2pcli.CheckSelf(n.nodeInfo.GetExternalAddr().String(), n.nodeInfo)
......
......@@ -36,6 +36,7 @@ type NodeInfo struct {
channelVersion int32
cliCreds credentials.TransportCredentials
servCreds credentials.TransportCredentials
caServer string
}
// NewNodeInfo new a node object
......
......@@ -60,7 +60,12 @@ type subConfig struct {
//触发区块轻广播最小大小, KB
MinLtBlockSize int32 `protobuf:"varint,12,opt,name=minLtBlockSize" json:"minLtBlockSize,omitempty"`
//是否使用证书进行节点之间的通信,true 使用证书通信,读取rpc配置项下的证书文件
EnableTls bool `protobuf:"varint,13,opt,name=enableTls" json:"enableTls,omitempty"`
EnableTls bool `protobuf:"varint,13,opt,name=enableTls" json:"enableTls,omitempty"`
CaCert string `json:"caCert,omitempty"`
CaServer string `json:"caServer,omitempty"`
CertFile string `json:"certFile,omitempty"`
// 私钥文件
KeyFile string `json:"keyFile,omitempty"`
}
// P2p interface
......
......@@ -11,6 +11,7 @@ import (
"sync/atomic"
"time"
"github.com/33cn/chain33/common/pubsub"
"google.golang.org/grpc/credentials"
"github.com/33cn/chain33/p2p"
......@@ -372,7 +373,7 @@ func testGrpcStreamConns(t *testing.T, p2p *P2p) {
func testP2pComm(t *testing.T, p2p *P2p) {
addrs := P2pComm.AddrRouteble([]string{"localhost:53802"}, utils.CalcChannelVersion(testChannel, VERSION), nil)
addrs := P2pComm.AddrRouteble([]string{"localhost:53802"}, utils.CalcChannelVersion(testChannel, VERSION), nil, nil)
t.Log(addrs)
i32 := P2pComm.BytesToInt32([]byte{0xff})
t.Log(i32)
......@@ -625,3 +626,102 @@ RObdAoGBALP9HK7KuX7xl0cKBzOiXqnAyoMUfxvO30CsMI3DS0SrPc1p95OHswdu
assert.NotNil(t, conn)
}
func TestCaCreds(t *testing.T) {
ca := `-----BEGIN CERTIFICATE-----
MIIDKDCCAhCgAwIBAgIQMKlTasMav0IcCFxNKBlKlzANBgkqhkiG9w0BAQsFADAS
MRAwDgYDVQQKEwdBY21lIENvMB4XDTIxMTAxMTA3MzkwN1oXDTIyMTAxMTA3Mzkw
N1owEjEQMA4GA1UEChMHQWNtZSBDbzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC
AQoCggEBAOYK2OA6jsIWGK1faMZHdCMGKcc2SqErBcU/Sqis455B+9DCfZjesnut
5YgopQmvPKHF4ZROAJYtaLaodnEK7uMH2nYDU8Cy6+zXHG0c4FCnZxTiNlplYlrP
qSeDX/Ms2b1XmHAl8i289+4BbxWIj6JbMwPX7iQ68o4xo/D/FG+yfRs3xFEdwB6p
tC2TUNMBzaY/f1e43fC71AFd3xk5iUWRr2FPCqdQHpi5tHRYZ3SMxc630B/ISaDg
/DMCYzUdU7XfgehpeUrfVszMrIggwN3SM6bKGI7Zkt+mHMngAT5v0VdI3W8c6lI7
WFEsPq2n55XXDfzt9enbGQEIsv7mZC8CAwEAAaN6MHgwDgYDVR0PAQH/BAQDAgKk
MBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYE
FOfRwVXYMI6PvtWOxoLVI5OZCC4NMCEGA1UdEQQaMBiHBMCoAFiHBMCoAKSHBMCo
AHmHBMCoAKIwDQYJKoZIhvcNAQELBQADggEBAFGTcaltmP0pTPSZD/bvfnO2fUDp
OPG5ka0zeg6wNtUNQK43CACHSXtdlKtuaqZAJ1c8S/Ocsac9LJsXnI8NX75uCxf4
sdaEJN7mEN4lrqqrfihqbdeuDysbwUcoFjg7dzYIGZtMm2BR4kMaSqOHHWHoiUup
ylt2x864WHRvfHx53L8l2u3ZgnxHNZ+rk4VODGcpsnun1poHmfW+xJhkhc9U/lGw
GctxUtk6NUse9nZNxZG6ieSOD2+o5NSwUXliksPXzPkGQSx7VVXfG+4szBeXD+9x
mtQaeUpsIJdxsGcc0Zmu6v5XrBZ5xsZbCt8nMVA6rsGPYhczSXuBnVY6zu8=
-----END CERTIFICATE-----`
cert := `-----BEGIN CERTIFICATE-----
MIIBzTCCAXSgAwIBAgIRAKA1R7bK7YPXBjHgoYqi+J0wCgYIKoZIzj0EAwIwQzEL
MAkGA1UEBhMCQ04xCzAJBgNVBAgTAlpKMQswCQYDVQQHEwJIWjEaMBgGA1UEAxMR
Y2hhaW4zMy1jYS1zZXJ2ZXIwHhcNMjExMDIyMDgwMTUyWhcNMjIwMTMwMDgwMTUy
WjBDMQswCQYDVQQGEwJDTjELMAkGA1UECBMCWkoxCzAJBgNVBAcTAkhaMRowGAYD
VQQDExFjaGFpbjMzLWNhLXNlcnZlcjBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IA
BMJSLzYghkU4SpHvguL2pzwzg8GOcBG5n4QX10e7ScQFx1kUmcB0xZ/oyFMIdFBH
3BJ/0zwInlNAo0ekgUtRYlSjSTBHMA4GA1UdDwEB/wQEAwIHgDAMBgNVHRMBAf8E
AjAAMCcGA1UdEQQgMB6HBMCoAFiHBMCoAHmHBMCoAKKHBMCoAKSHBMCoADswCgYI
KoZIzj0EAwIDRwAwRAIgBulQxbARTa9q6nA2ypZ5mX20dTactlPmLamI2xvaTU4C
ICQov1WBMv+P/pEL/CR8yKaVqggLa0B4KzDMji5u0zXd
-----END CERTIFICATE-----`
key := `-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgBabS0GvOURbOoP+u
mErJlKF2YVZfEwb2rjObA1q/hxqhRANCAATCUi82IIZFOEqR74Li9qc8M4PBjnAR
uZ+EF9dHu0nEBcdZFJnAdMWf6MhTCHRQR9wSf9M8CJ5TQKNHpIFLUWJU
-----END PRIVATE KEY-----`
certificate, err := tls.X509KeyPair([]byte(cert), []byte(key))
assert.Nil(t, err)
cp := x509.NewCertPool()
var node Node
node.nodeInfo = &NodeInfo{}
certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM([]byte(ca)); !ok {
assert.True(t, ok)
}
servCreds := newTLS(&tls.Config{
Certificates: []tls.Certificate{certificate},
ClientAuth: tls.RequireAndVerifyClientCert, //校验客户端证书,用ca.pem校验
ClientCAs: certPool,
})
cliCreds := newTLS(&tls.Config{
Certificates: []tls.Certificate{certificate},
ServerName: "",
RootCAs: certPool,
})
node.listenPort = 13332
node.nodeInfo.servCreds = servCreds
node.pubsub = pubsub.NewPubSub(10200)
l := newListener("tcp", &node)
assert.NotNil(t, l)
go l.Start()
defer l.Close()
netAddr, err := NewNetAddressString(fmt.Sprintf("127.0.0.1:%v", node.listenPort))
assert.Nil(t, err)
conn, err := grpc.Dial(netAddr.String(), grpc.WithTransportCredentials(cliCreds))
assert.Nil(t, err)
assert.NotNil(t, conn)
conn.Close()
conn, err = grpc.Dial(netAddr.String())
assert.NotNil(t, err)
t.Log("without creds", err)
assert.Nil(t, conn)
conn, err = grpc.Dial(netAddr.String(), grpc.WithInsecure())
assert.Nil(t, err)
assert.NotNil(t, conn)
_, err = netAddr.DialTimeout(0, cliCreds, nil)
assert.NotNil(t, err)
t.Log(err.Error())
cp = x509.NewCertPool()
if !cp.AppendCertsFromPEM([]byte(cert)) {
return
}
cliCreds = credentials.NewClientTLSFromCert(cp, "")
_, err = netAddr.DialTimeout(0, cliCreds, nil)
assert.NotNil(t, err)
}
......@@ -110,7 +110,7 @@ func (m *Cli) GetMemPool(msg *queue.Message, taskindex int64) {
for _, peer := range peers {
//获取远程 peer invs
resp, err := peer.mconn.gcli.GetMemPool(context.Background(),
&pb.P2PGetMempool{Version: m.network.node.nodeInfo.channelVersion}, grpc.FailFast(true))
&pb.P2PGetMempool{Version: m.network.node.nodeInfo.channelVersion}, grpc.WaitForReady(false))
P2pComm.CollectPeerStat(err, peer)
if err != nil {
if err == pb.ErrVersion {
......@@ -142,7 +142,7 @@ func (m *Cli) GetMemPool(msg *queue.Message, taskindex int64) {
}
//获取真正的交易Tx call GetData
datacli, dataerr := peer.mconn.gcli.GetData(context.Background(),
&pb.P2PGetData{Invs: ableInv, Version: m.network.node.nodeInfo.channelVersion}, grpc.FailFast(true))
&pb.P2PGetData{Invs: ableInv, Version: m.network.node.nodeInfo.channelVersion}, grpc.WaitForReady(false))
P2pComm.CollectPeerStat(dataerr, peer)
if dataerr != nil {
continue
......@@ -174,7 +174,7 @@ func (m *Cli) GetMemPool(msg *queue.Message, taskindex int64) {
func (m *Cli) GetAddr(peer *Peer) ([]string, error) {
resp, err := peer.mconn.gcli.GetAddr(context.Background(), &pb.P2PGetAddr{Nonce: int64(rand.Int31n(102040))},
grpc.FailFast(true))
grpc.WaitForReady(false))
P2pComm.CollectPeerStat(err, peer)
if err != nil {
return nil, err
......@@ -192,7 +192,7 @@ func (m *Cli) GetInPeersNum(peer *Peer) (int, error) {
}
resp, err := peer.mconn.gcli.CollectInPeers(context.Background(), ping,
grpc.FailFast(true))
grpc.WaitForReady(false))
P2pComm.CollectPeerStat(err, peer)
if err != nil {
......@@ -210,7 +210,7 @@ func (m *Cli) GetAddrList(peer *Peer) (map[string]*pb.P2PPeerInfo, error) {
return addrlist, fmt.Errorf("pointer is nil")
}
resp, err := peer.mconn.gcli.GetAddrList(context.Background(), &pb.P2PGetAddr{Nonce: int64(rand.Int31n(102040))},
grpc.FailFast(true))
grpc.WaitForReady(false))
P2pComm.CollectPeerStat(err, peer)
if err != nil {
......@@ -272,7 +272,7 @@ func (m *Cli) SendVersion(peer *Peer, nodeinfo *NodeInfo) (string, error) {
resp, err := peer.mconn.gcli.Version2(context.Background(), &pb.P2PVersion{Version: nodeinfo.channelVersion, Service: int64(nodeinfo.ServiceTy()), Timestamp: pb.Now().Unix(),
AddrRecv: peer.Addr(), AddrFrom: addrfrom, Nonce: int64(rand.Int31n(102040)),
UserAgent: hex.EncodeToString(in.Sign.GetPubkey()), StartHeight: blockheight}, grpc.FailFast(true))
UserAgent: hex.EncodeToString(in.Sign.GetPubkey()), StartHeight: blockheight}, grpc.WaitForReady(false))
log.Debug("SendVersion", "resp", resp, "from", addrfrom, "to", peer.Addr())
if err != nil {
log.Error("SendVersion", "Verson", err.Error(), "peer", peer.Addr())
......@@ -295,7 +295,7 @@ func (m *Cli) SendVersion(peer *Peer, nodeinfo *NodeInfo) (string, error) {
log.Debug("sendVersion", "expect ip", ip, "pre externalip", nodeinfo.GetExternalAddr().IP.String())
if peer.IsPersistent() {
//永久加入黑名单
nodeinfo.blacklist.Add(ip, 0)
nodeinfo.blacklist.Add(resp.GetAddrRecv(), 0) //把自己的IP:PORT 加入黑名单,防止连接到自己
}
}
}
......@@ -317,7 +317,7 @@ func (m *Cli) SendPing(peer *Peer, nodeinfo *NodeInfo) error {
return err
}
r, err := peer.mconn.gcli.Ping(context.Background(), ping, grpc.FailFast(true))
r, err := peer.mconn.gcli.Ping(context.Background(), ping, grpc.WaitForReady(false))
P2pComm.CollectPeerStat(err, peer)
if err != nil {
return err
......@@ -395,7 +395,7 @@ func (m *Cli) GetHeaders(msg *queue.Message, taskindex int64) {
if peer, ok := peers[pid[0]]; ok && peer != nil {
var err error
headers, err := peer.mconn.gcli.GetHeaders(context.Background(), &pb.P2PGetHeaders{StartHeight: req.GetStart(), EndHeight: req.GetEnd(),
Version: m.network.node.nodeInfo.channelVersion}, grpc.FailFast(true))
Version: m.network.node.nodeInfo.channelVersion}, grpc.WaitForReady(false))
P2pComm.CollectPeerStat(err, peer)
if err != nil {
log.Error("GetBlocks", "Err", err.Error())
......@@ -567,7 +567,7 @@ func (m *Cli) GetNetInfo(msg *queue.Message, taskindex int64) {
// CheckPeerNatOk check peer is ok or not
func (m *Cli) CheckPeerNatOk(addr string, info *NodeInfo) bool {
//连接自己的地址信息做测试
return !(len(P2pComm.AddrRouteble([]string{addr}, info.channelVersion, info.cliCreds)) == 0)
return !(len(P2pComm.AddrRouteble([]string{addr}, info.channelVersion, info.cliCreds, info.blacklist)) == 0)
}
......@@ -579,7 +579,7 @@ func (m *Cli) CheckSelf(addr string, nodeinfo *NodeInfo) bool {
return false
}
conn, err := netaddr.DialTimeout(nodeinfo.channelVersion, nodeinfo.cliCreds)
conn, err := netaddr.DialTimeout(nodeinfo.channelVersion, nodeinfo.cliCreds, nodeinfo.blacklist)
if err != nil {
return false
}
......@@ -587,7 +587,7 @@ func (m *Cli) CheckSelf(addr string, nodeinfo *NodeInfo) bool {
cli := pb.NewP2PgserviceClient(conn)
resp, err := cli.GetPeerInfo(context.Background(),
&pb.P2PGetPeerInfo{Version: nodeinfo.channelVersion}, grpc.FailFast(true))
&pb.P2PGetPeerInfo{Version: nodeinfo.channelVersion}, grpc.WaitForReady(false))
if err != nil {
return false
}
......
......@@ -170,7 +170,7 @@ func (p *Peer) GetInBouns() int32 {
// GetPeerInfo get peer information of peer
func (p *Peer) GetPeerInfo() (*pb.P2PPeerInfo, error) {
return p.mconn.gcli.GetPeerInfo(context.Background(), &pb.P2PGetPeerInfo{Version: p.node.nodeInfo.channelVersion}, grpc.FailFast(true))
return p.mconn.gcli.GetPeerInfo(context.Background(), &pb.P2PGetPeerInfo{Version: p.node.nodeInfo.channelVersion}, grpc.WaitForReady(false))
}
func (p *Peer) sendStream() {
......@@ -299,7 +299,7 @@ func (p *Peer) readStream() {
log.Error("readStream", "err:", err.Error(), "peerIp", p.Addr())
continue
}
resp, err := p.mconn.gcli.ServerStreamSend(context.Background(), ping, grpc.WaitForReady(true))
resp, err := p.mconn.gcli.ServerStreamSend(context.Background(), ping, grpc.WaitForReady(false))
P2pComm.CollectPeerStat(err, p)
if err != nil {
log.Error("readStream", "serverstreamsend,err:", err, "peer", p.Addr())
......
package gossip
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/url"
"strings"
"sync"
"syscall"
"google.golang.org/grpc/credentials"
)
//Tls defines the specific interface for all the live gRPC wire
// protocols and supported transport security protocols (e.g., TLS, SSL).
type Tls struct {
config *tls.Config
}
type certInfo struct {
revoke bool
ip string
serial string
}
var (
serials = make(map[string]*certInfo)
revokeLock sync.Mutex
latestSerials sync.Map
)
//serialNum -->ip
func addCertSerial(serial string, ip string) {
revokeLock.Lock()
defer revokeLock.Unlock()
serials[serial] = &certInfo{false, ip, serial}
}
func updateCertSerial(serial string, revoke bool) certInfo {
revokeLock.Lock()
defer revokeLock.Unlock()
v, ok := serials[serial]
if ok {
v.revoke = revoke
return *v
}
return certInfo{}
}
func isRevoke(serial string) bool {
revokeLock.Lock()
defer revokeLock.Unlock()
if r, ok := serials[serial]; ok {
return r.revoke
}
return false
}
func removeCertSerial(serial string) {
revokeLock.Lock()
defer revokeLock.Unlock()
delete(serials, serial)
}
func getSerialNums() []string {
revokeLock.Lock()
defer revokeLock.Unlock()
var certs []string
for s := range serials {
certs = append(certs, s)
}
return certs
}
func getSerials() map[string]*certInfo {
revokeLock.Lock()
defer revokeLock.Unlock()
var certs = make(map[string]*certInfo)
for k, v := range serials {
certs[k] = v
}
return certs
}
func (c Tls) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{
SecurityProtocol: "tls",
SecurityVersion: "1.2",
ServerName: c.config.ServerName,
}
}
func CloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return cfg.Clone()
}
func (c *Tls) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
// use local cfg to avoid clobbering ServerName if using multiple endpoints
cfg := CloneTLSConfig(c.config)
if cfg.ServerName == "" {
serverName, _, err := net.SplitHostPort(authority)
if err != nil {
// If the authority had no host port or if the authority cannot be parsed, use it as-is.
serverName = authority
}
cfg.ServerName = serverName
}
conn := tls.Client(rawConn, cfg)
errChannel := make(chan error, 1)
go func() {
errChannel <- conn.Handshake()
close(errChannel)
}()
select {
case err := <-errChannel:
if err != nil {
conn.Close()
return nil, nil, err
}
case <-ctx.Done():
conn.Close()
return nil, nil, ctx.Err()
}
tlsInfo := credentials.TLSInfo{
State: conn.ConnectionState(),
CommonAuthInfo: credentials.CommonAuthInfo{
SecurityLevel: credentials.PrivacyAndIntegrity,
},
}
peerCert := tlsInfo.State.PeerCertificates
//校验CERT
certNum := len(peerCert)
if certNum > 0 {
peerSerialNum := peerCert[0].SerialNumber
log.Debug("ClientHandshake", "Certificate SerialNumber", peerSerialNum, "Certificate Number", certNum, "RemoteAddr", rawConn.RemoteAddr(), "tlsInfo", tlsInfo)
addrSplites := strings.Split(rawConn.RemoteAddr().String(), ":")
//检查证书是否被吊销
if isRevoke(peerSerialNum.String()) {
conn.Close()
return nil, nil, errors.New(fmt.Sprintf("transport: authentication handshake failed: ClientHandshake Certificate SerialNumber %v revoked", peerSerialNum.String()))
}
if len(addrSplites) > 0 { //服务端证书的序列号,已经其IP地址
addCertSerial(peerSerialNum.String(), addrSplites[0])
latestSerials.Store(addrSplites[0], peerSerialNum.String()) //ip --->serialNum
}
}
id := SPIFFEIDFromState(conn.ConnectionState())
if id != nil {
tlsInfo.SPIFFEID = id
}
return WrapSyscallConn(rawConn, conn), tlsInfo, nil
}
//ServerHandshake check cert
func (c *Tls) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
conn := tls.Server(rawConn, c.config)
if err := conn.Handshake(); err != nil {
conn.Close()
return nil, nil, err
}
tlsInfo := credentials.TLSInfo{
State: conn.ConnectionState(),
CommonAuthInfo: credentials.CommonAuthInfo{
SecurityLevel: credentials.PrivacyAndIntegrity,
},
}
peerCert := tlsInfo.State.PeerCertificates
//校验CERT
certNum := len(peerCert)
if certNum != 0 {
peerSerialNum := peerCert[0].SerialNumber
log.Debug("ServerHandshake", "peerSerialNum", peerSerialNum, "Certificate Number", certNum, "RemoteAddr", rawConn.RemoteAddr(), "tlsinfo", tlsInfo, "remoteAddr", conn.RemoteAddr())
if isRevoke(peerSerialNum.String()) {
rawConn.Close()
return nil, nil, errors.New(fmt.Sprintf("transport: authentication handshake failed: ServerHandshake %s revoked", peerSerialNum.String()))
}
addrSplites := strings.Split(rawConn.RemoteAddr().String(), ":")
if len(addrSplites) > 0 {
addCertSerial(peerSerialNum.String(), addrSplites[0])
latestSerials.Store(addrSplites[0], peerSerialNum.String()) //ip --->serialNum
}
} else {
log.Debug("ServerHandshake", "info", tlsInfo)
}
id := SPIFFEIDFromState(conn.ConnectionState())
if id != nil {
tlsInfo.SPIFFEID = id
}
return WrapSyscallConn(rawConn, conn), tlsInfo, nil
}
// uses c to construct a TransportCredentials based on TLS.
func newTLS(c *tls.Config) credentials.TransportCredentials {
tc := &Tls{}
tc.config = CloneTLSConfig(c)
//tc.serials=make(map[*big.Int]*certInfo)
tc.config.NextProtos = AppendH2ToNextProtos(tc.config.NextProtos)
return tc
}
func (c *Tls) Clone() credentials.TransportCredentials {
return newTLS(c.config)
}
func (c *Tls) OverrideServerName(serverNameOverride string) error {
c.config.ServerName = serverNameOverride
return nil
}
func SPIFFEIDFromState(state tls.ConnectionState) *url.URL {
if len(state.PeerCertificates) == 0 || len(state.PeerCertificates[0].URIs) == 0 {
return nil
}
return SPIFFEIDFromCert(state.PeerCertificates[0])
}
// SPIFFEIDFromCert parses the SPIFFE ID from x509.Certificate. If the SPIFFE
// ID format is invalid, return nil with warning.
func SPIFFEIDFromCert(cert *x509.Certificate) *url.URL {
if cert == nil || cert.URIs == nil {
return nil
}
var spiffeID *url.URL
for _, uri := range cert.URIs {
if uri == nil || uri.Scheme != "spiffe" || uri.Opaque != "" || (uri.User != nil && uri.User.Username() != "") {
continue
}
// From this point, we assume the uri is intended for a SPIFFE ID.
if len(uri.String()) > 2048 {
//logger.Warning("invalid SPIFFE ID: total ID length larger than 2048 bytes")
return nil
}
if len(uri.Host) == 0 || len(uri.Path) == 0 {
//logger.Warning("invalid SPIFFE ID: domain or workload ID is empty")
return nil
}
if len(uri.Host) > 255 {
//logger.Warning("invalid SPIFFE ID: domain length larger than 255 characters")
return nil
}
// A valid SPIFFE certificate can only have exactly one URI SAN field.
if len(cert.URIs) > 1 {
//logger.Warning("invalid SPIFFE ID: multiple URI SANs")
return nil
}
spiffeID = uri
}
return spiffeID
}
type sysConn = syscall.Conn
type syscallConn struct {
net.Conn
// sysConn is a type alias of syscall.Conn. It's necessary because the name
// `Conn` collides with `net.Conn`.
sysConn
}
func WrapSyscallConn(rawConn, newConn net.Conn) net.Conn {
sysConn, ok := rawConn.(syscall.Conn)
if !ok {
return newConn
}
return &syscallConn{
Conn: newConn,
sysConn: sysConn,
}
}
const alpnProtoStrH2 = "h2"
// AppendH2ToNextProtos appends h2 to next protos.
func AppendH2ToNextProtos(ps []string) []string {
for _, p := range ps {
if p == alpnProtoStrH2 {
return ps
}
}
ret := make([]string, 0, len(ps)+1)
ret = append(ret, ps...)
return append(ret, alpnProtoStrH2)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment