Commit ba502228 authored by libangzhu's avatar libangzhu

增加ca认证下的tls通信

parent c81872d0
......@@ -58,7 +58,7 @@ func (Comm) ParaseNetAddr(addr string) (string, int64, error) {
}
// AddrRouteble address router ,return enbale address
func (Comm) AddrRouteble(addrs []string, version int32, creds credentials.TransportCredentials) []string {
func (Comm) AddrRouteble(addrs []string, version int32, creds credentials.TransportCredentials, blist *BlackList) []string {
var enableAddrs []string
for _, addr := range addrs {
......@@ -67,7 +67,7 @@ func (Comm) AddrRouteble(addrs []string, version int32, creds credentials.Transp
log.Error("AddrRouteble", "NewNetAddressString", err.Error())
continue
}
conn, err := netaddr.DialTimeout(version, creds)
conn, err := netaddr.DialTimeout(version, creds, blist)
if err != nil {
//log.Error("AddrRouteble", "DialTimeout", err.Error())
continue
......@@ -110,15 +110,18 @@ func (c Comm) GetLocalAddr() string {
func (c Comm) dialPeerWithAddress(addr *NetAddress, persistent bool, node *Node) (*Peer, error) {
log.Debug("dialPeerWithAddress")
conn, err := addr.DialTimeout(node.nodeInfo.channelVersion, node.nodeInfo.cliCreds)
conn, err := addr.DialTimeout(node.nodeInfo.channelVersion, node.nodeInfo.cliCreds, node.nodeInfo.blacklist)
if err != nil {
log.Error("dialPeerWithAddress","DialTimeoutErr",err.Error())
return nil, err
}
peer, err := c.newPeerFromConn(conn, addr, node)
if err != nil {
log.Error("dialPeerWithAddress","newPeerFromConn",err)
err = conn.Close()
return nil, err
}
peer.SetAddr(addr)
......@@ -151,7 +154,6 @@ func (c Comm) dialPeerWithAddress(addr *NetAddress, persistent bool, node *Node)
peer.Close()
return nil, errors.New(fmt.Sprintf("duplicate connect %v", resp.UserAgent))
}
node.peerStore.Store(addr.String(), resp.UserAgent)
peer.SetPeerName(resp.UserAgent)
return peer, nil
......
......@@ -25,6 +25,7 @@ var (
CheckActivePeersInterVal = 5 * time.Second
CheckBlackListInterVal = 30 * time.Second
CheckCfgSeedsInterVal = 1 * time.Minute
CheckCfgCertInterVal = 30 * time.Second
)
const (
......
......@@ -6,6 +6,7 @@ package gossip
import (
"fmt"
"math/big"
"math/rand"
"net"
"sync"
......@@ -91,6 +92,12 @@ Retry:
if err != nil {
return nil, err
}
if serialNum,ok:= latestSerials.Load(ip);ok{
bn,_:=big.NewInt(1).SetString(serialNum.(string),10)
if isRevoke(bn){//证书被吊销 拒绝接口请求
return nil, fmt.Errorf("cert %v revoked", serialNum.(string))
}
}
if pServer.node.nodeInfo.blacklist.Has(ip) {
return nil, fmt.Errorf("blacklist %v no authorized", ip)
}
......@@ -116,6 +123,13 @@ Retry:
if err != nil {
return err
}
if serialNum,ok:= latestSerials.Load(ip);ok{
bn,_:=big.NewInt(1).SetString(serialNum.(string),10)
if isRevoke(bn){//证书被吊销 拒绝接口请求
return fmt.Errorf("cert %v revoked", serialNum.(string))
}
}
if pServer.node.nodeInfo.blacklist.Has(ip) {
return fmt.Errorf("blacklist %v no authorized", ip)
}
......@@ -146,12 +160,12 @@ Retry:
opts = append(opts, msgRecvOp, msgSendOp, grpc.KeepaliveEnforcementPolicy(kaep), keepOp, maxStreams, StatsOp)
if node.nodeInfo.servCreds != nil {
opts = append(opts, grpc.Creds(node.nodeInfo.servCreds))
}
dl.server = grpc.NewServer(opts...)
dl.p2pserver = pServer
pb.RegisterP2PgserviceServer(dl.server, pServer)
return dl
}
......
......@@ -6,13 +6,14 @@ package gossip
import (
"bytes"
"github.com/33cn/chain33/rpc/jsonclient"
"io"
"math/big"
"net/http"
"strings"
"time"
"github.com/33cn/chain33/p2p/utils"
"github.com/33cn/chain33/types"
)
......@@ -599,3 +600,90 @@ func (n *Node) monitorCfgSeeds() {
}
}
func (n *Node) monitorCerts() {
if !n.nodeInfo.cfg.EnableTls {
return
}
ticker := time.NewTicker(CheckCfgCertInterVal)
defer ticker.Stop()
jcli, err := jsonclient.New("chain33-ca-server",n.nodeInfo.caServer , false)
if err != nil {
log.Error("monitorCerts", "rpc call err", err)
return
}
delayT:=time.Now().Add(time.Minute*2)
for {
select {
case <-ticker.C:
//check serialNum
if !time.Now().After(delayT){
continue
}
var resp []string
var s Serial
s.Serials =getSerialNums()
if len(s.Serials) == 0 {
continue
}
log.Debug("check cert serialNum++++++","certNum.",len(s.Serials ))
err = jcli.Call("Validate", s, &resp)
if err != nil {
log.Error("monitorCerts", "rpc call err", err)
continue
}
log.Debug("monitorCerts","resp", resp)
tempCerts := getSerials()
for _, serialNum := range resp {
//被吊销的证书序列号
var ok bool
sNum := big.NewInt(1)
sNum, ok = sNum.SetString(serialNum, 10)
if !ok {
log.Error("monitorCerts", "big.Int Setstring err", serialNum)
continue
}
//设置证书序列号状态
certinfo := updateCertSerial(sNum, true)
delete(tempCerts, sNum.String())
if certinfo != nil {
//断开节点连接
//if ip-->serialNum == sNum{
// close connect
//}else{
// }
//log.Info("monitorCerts","add blacklist",certinfo.ip)
//n.nodeInfo.blacklist.Add(certinfo.ip, 60)
for pname,peer:=range n.nodeInfo.peerInfos.GetPeerInfos(){
if peer.GetAddr()==certinfo.ip {
v,ok:= latestSerials.Load(certinfo.ip)
if ok && v.(string)==serialNum{
n.remove(pname)//断开已经连接的节点
}
}
}
}
}
log.Debug("monitorCert","tempCerts",tempCerts)
//处理解除吊销的节点
for serialNum, info := range tempCerts {
if info.revoke {
// 被撤销的证书恢复正常
sNum := big.NewInt(1)
sNum, _ = sNum.SetString(serialNum, 10)
updateCertSerial(sNum, !info.revoke)
}
/*
//拉入黑名单的节点 恢复正常
if n.nodeInfo.blacklist.Has(info.ip) {
n.nodeInfo.blacklist.Delete(info.ip)
}*/
}
}
}
}
......@@ -7,6 +7,8 @@ package gossip
import (
"context"
"fmt"
pr "google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"net"
"strconv"
"time"
......@@ -136,14 +138,14 @@ func (na *NetAddress) Copy() *NetAddress {
// DialTimeout calls net.DialTimeout on the address.
func isCompressSupport(err error) bool {
var errstr = `grpc: Decompressor is not installed for grpc-encoding "gzip"`
if grpc.Code(err) == codes.Unimplemented && grpc.ErrorDesc(err) == errstr {
if status.Code(err) == codes.Unimplemented && status.Convert(err).Message() == errstr {
return false
}
return true
}
// DialTimeout dial timeout
func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCredentials) (*grpc.ClientConn, error) {
func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCredentials, bList *BlackList) (*grpc.ClientConn, error) {
ch := make(chan grpc.ServiceConfig, 1)
ch <- P2pComm.GrpcConfig()
......@@ -152,8 +154,7 @@ func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCred
cliparm.Timeout = 10 * time.Second //ping后的获取ack消息超时时间
cliparm.PermitWithoutStream = true //启动keepalive 进行检查
keepaliveOp := grpc.WithKeepaliveParams(cliparm)
timeoutOp := grpc.WithTimeout(time.Second * 3)
log.Debug("NetAddress", "Dial", na.String())
log.Info("NetAddress", "Dial------------->", na.String())
maxMsgSize := pb.MaxBlockSize + 1024*1024
//配置SSL连接
var secOpt grpc.DialOption
......@@ -162,16 +163,63 @@ func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCred
} else {
secOpt = grpc.WithTransportCredentials(creds)
}
//接口拦截器
interceptor := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
// 黑名单校验
//checkAuth
log.Info("client interceptor")
ip, _, err := net.SplitHostPort(na.String())
if err != nil {
return err
}
log.Info("interceptor client","remoteAddr",na.String())
if bList != nil && bList.Has(ip) {
return fmt.Errorf("blacklist peer %v no authorized", ip)
}
return invoker(ctx, method, req, reply, cc, opts...)
}
//流拦截器
interceptorStream := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
log.Info("client interceptorStream")
ip, _, err := net.SplitHostPort(na.String())
if err != nil {
return nil, err
}
log.Info("interceptorStream client","remoteAddr",na.String())
if bList.Has(ip) {
return nil, fmt.Errorf("blacklist peer %v no authorized", ip)
}
return streamer(ctx, desc, cc, method, opts...)
}
//grpc.WithPerRPCCredentials
conn, err := grpc.Dial(na.String(),
tcpAddr,err:= net.ResolveTCPAddr("tcp",na.String())
if err!=nil{
return nil, err
}
peer := &pr.Peer{
Addr:tcpAddr,
AuthInfo: nil,
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
ctxV:= pr.NewContext(ctx,peer)
conn, err := grpc.DialContext(ctxV, na.String(),
grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMsgSize)),
grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(maxMsgSize)),
grpc.WithServiceConfig(ch), keepaliveOp, timeoutOp, secOpt)
grpc.WithServiceConfig(ch), keepaliveOp, secOpt,
grpc.WithUnaryInterceptor(interceptor), grpc.WithStreamInterceptor(interceptorStream))
if err != nil {
log.Debug("grpc DialCon", "did not connect", err, "addr", na.String())
log.Error("grpc DialCon", "did not connect", err, "addr", na.String())
return nil, err
}
//p2p version check 通过版本协议,获取通信session
//判断是否对方是否支持压缩
......@@ -187,7 +235,7 @@ func (na *NetAddress) DialTimeout(version int32, creds credentials.TransportCred
ch2 := make(chan grpc.ServiceConfig, 1)
ch2 <- P2pComm.GrpcConfig()
log.Debug("NetAddress", "Dial with unCompressor", na.String())
conn, err = grpc.Dial(na.String(), secOpt, grpc.WithServiceConfig(ch2), keepaliveOp, timeoutOp)
conn, err = grpc.DialContext(ctx, na.String(), secOpt, grpc.WithServiceConfig(ch2), keepaliveOp)
}
......
......@@ -5,11 +5,12 @@
package gossip
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"math/rand"
"google.golang.org/grpc/credentials"
"github.com/33cn/chain33/p2p"
//"strings"
......@@ -85,6 +86,7 @@ type Node struct {
pubsub *pubsub.PubSub
chainCfg *types.Chain33Config
p2pMgr *p2p.Manager
//tls *Tls
}
// SetQueueClient return client for nodeinfo
......@@ -102,6 +104,7 @@ func NewNode(mgr *p2p.Manager, mcfg *subConfig) (*Node, error) {
pubsub: pubsub.NewPubSub(10200),
p2pMgr: mgr,
}
//node.tls = &Tls{serials:make(map[*big.Int]*certInfo)}
node.listenPort = 13802
if mcfg.Port != 0 && mcfg.Port <= 65535 && mcfg.Port > 1024 {
node.listenPort = int(mcfg.Port)
......@@ -126,18 +129,43 @@ func NewNode(mgr *p2p.Manager, mcfg *subConfig) (*Node, error) {
node.chainCfg = cfg
if mcfg.EnableTls { //读取证书,初始化tls客户端
var err error
node.nodeInfo.cliCreds, err = credentials.NewClientTLSFromFile(cfg.GetModuleConfig().RPC.CertFile, "")
cert, err := tls.LoadX509KeyPair(mcfg.CertFile, mcfg.KeyFile)
if err != nil {
panic(err)
}
node.nodeInfo.servCreds, err = credentials.NewServerTLSFromFile(cfg.GetModuleConfig().RPC.CertFile, cfg.GetModuleConfig().RPC.KeyFile)
certPool := x509.NewCertPool()
//添加CA校验
//把CA证书读进去,尝试动态更新CA中的吊销列表
ca, err := ioutil.ReadFile(mcfg.CaCert)
if err != nil {
panic(err)
}
if ok := certPool.AppendCertsFromPEM(ca); !ok {
panic("certPool.AppendCertsFromPEM err")
}
node.nodeInfo.servCreds = newTLS(&tls.Config{
Certificates: []tls.Certificate{cert},
ClientAuth: tls.RequireAndVerifyClientCert, //校验客户端证书,用ca.pem校验
ClientCAs: certPool,
})
// 构建基于 TLS 的 TransportCredentials 选项
// 在 Client 请求 Server 端时,Client 端会使用根证书和 ServerName 去对 Server 端进行校验
node.nodeInfo.cliCreds = newTLS(&tls.Config{
Certificates: []tls.Certificate{cert},
ServerName: "",
RootCAs: certPool,
})
node.nodeInfo.caServer=mcfg.CaServer
}
if mcfg.ServerStart {
node.server = newListener(protocol, node)
}
return node, nil
}
......@@ -171,7 +199,7 @@ func (n *Node) doNat() {
}
testExaddr := fmt.Sprintf("%v:%v", n.nodeInfo.GetExternalAddr().IP.String(), n.listenPort)
log.Info("TestNetAddr", "testExaddr", testExaddr)
if len(P2pComm.AddrRouteble([]string{testExaddr}, n.nodeInfo.channelVersion, n.nodeInfo.cliCreds)) != 0 {
if len(P2pComm.AddrRouteble([]string{testExaddr}, n.nodeInfo.channelVersion, n.nodeInfo.cliCreds, n.nodeInfo.blacklist)) != 0 {
log.Info("node outside")
n.nodeInfo.SetNetSide(true)
if netexaddr, err := NewNetAddressString(testExaddr); err == nil {
......@@ -374,6 +402,7 @@ func (n *Node) monitor() {
go n.monitorFilter()
go n.monitorPeers()
go n.nodeReBalance()
go n.monitorCerts()
}
func (n *Node) needMore() bool {
......@@ -454,7 +483,7 @@ func (n *Node) natMapPort() {
time.Sleep(time.Second)
}
var err error
if len(P2pComm.AddrRouteble([]string{n.nodeInfo.GetExternalAddr().String()}, n.nodeInfo.channelVersion, n.nodeInfo.cliCreds)) != 0 { //判断能否连通要映射的端口
if len(P2pComm.AddrRouteble([]string{n.nodeInfo.GetExternalAddr().String()}, n.nodeInfo.channelVersion, n.nodeInfo.cliCreds, n.nodeInfo.blacklist)) != 0 { //判断能否连通要映射的端口
log.Info("natMapPort", "addr", "routeble")
p2pcli := NewNormalP2PCli() //检查要映射的IP地址是否已经被映射成功
ok := p2pcli.CheckSelf(n.nodeInfo.GetExternalAddr().String(), n.nodeInfo)
......
......@@ -36,6 +36,7 @@ type NodeInfo struct {
channelVersion int32
cliCreds credentials.TransportCredentials
servCreds credentials.TransportCredentials
caServer string
}
// NewNodeInfo new a node object
......
......@@ -61,6 +61,12 @@ type subConfig struct {
MinLtBlockSize int32 `protobuf:"varint,12,opt,name=minLtBlockSize" json:"minLtBlockSize,omitempty"`
//是否使用证书进行节点之间的通信,true 使用证书通信,读取rpc配置项下的证书文件
EnableTls bool `protobuf:"varint,13,opt,name=enableTls" json:"enableTls,omitempty"`
CaCert string `json:"caCert,omitempty"`
CaServer string `json:"caServer,omitempty"`
CertFile string `json:"certFile,omitempty"`
// 私钥文件
KeyFile string `json:"keyFile,omitempty"`
}
// P2p interface
......
......@@ -567,7 +567,7 @@ func (m *Cli) GetNetInfo(msg *queue.Message, taskindex int64) {
// CheckPeerNatOk check peer is ok or not
func (m *Cli) CheckPeerNatOk(addr string, info *NodeInfo) bool {
//连接自己的地址信息做测试
return !(len(P2pComm.AddrRouteble([]string{addr}, info.channelVersion, info.cliCreds)) == 0)
return !(len(P2pComm.AddrRouteble([]string{addr}, info.channelVersion, info.cliCreds, info.blacklist)) == 0)
}
......@@ -579,7 +579,7 @@ func (m *Cli) CheckSelf(addr string, nodeinfo *NodeInfo) bool {
return false
}
conn, err := netaddr.DialTimeout(nodeinfo.channelVersion, nodeinfo.cliCreds)
conn, err := netaddr.DialTimeout(nodeinfo.channelVersion, nodeinfo.cliCreds, nodeinfo.blacklist)
if err != nil {
return false
}
......
......@@ -55,6 +55,7 @@ type Peer struct {
taskChan chan interface{} //tx block
inBounds int32 //连接此节点的客户端节点数量
IsMaxInbouds bool
serialNnum string
}
// NewPeer produce a peer object
......
package gossip
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"google.golang.org/grpc/credentials"
"math/big"
"net"
"net/url"
"strings"
"sync"
"syscall"
)
var serials = make(map[string]*certInfo)
var latestSerials sync.Map
var revokeLock sync.Mutex
type Tls struct {
config *tls.Config
}
type certInfo struct {
revoke bool
ip string
serial string
}
type Serial struct {
Serials []string `json:"serials,omitempty"`
}
//serialNum -->ip
func addCertSerial(serial *big.Int, ip string) {
revokeLock.Lock()
defer revokeLock.Unlock()
serials[serial.String()] = &certInfo{false, ip,serial.String()}
}
func updateCertSerial(serial *big.Int, revoke bool) *certInfo {
revokeLock.Lock()
defer revokeLock.Unlock()
v, ok := serials[serial.String()]
if ok {
v.revoke = revoke
}else{
return nil
}
serials[serial.String()] = v
return v
}
func isRevoke(serial *big.Int) bool {
revokeLock.Lock()
defer revokeLock.Unlock()
if r, ok := serials[serial.String()]; ok {
return r.revoke
}
return false
}
func removeCertSerial(serial *big.Int) {
revokeLock.Lock()
defer revokeLock.Unlock()
delete(serials, serial.String())
}
func getSerialNums() []string {
revokeLock.Lock()
defer revokeLock.Unlock()
var certs []string
for s := range serials {
certs = append(certs, s)
}
return certs
}
func getSerials() map[string]*certInfo {
revokeLock.Lock()
defer revokeLock.Unlock()
var certs = make(map[string]*certInfo)
for k, v := range serials {
certs[k] = v
}
return certs
}
func (c Tls) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{
SecurityProtocol: "tls",
SecurityVersion: "1.2",
ServerName: c.config.ServerName,
}
}
func CloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return cfg.Clone()
}
func (c *Tls) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
// use local cfg to avoid clobbering ServerName if using multiple endpoints
cfg := CloneTLSConfig(c.config)
if cfg.ServerName == "" {
serverName, _, err := net.SplitHostPort(authority)
if err != nil {
// If the authority had no host port or if the authority cannot be parsed, use it as-is.
serverName = authority
}
cfg.ServerName = serverName
}
conn := tls.Client(rawConn, cfg)
errChannel := make(chan error, 1)
go func() {
errChannel <- conn.Handshake()
close(errChannel)
}()
select {
case err := <-errChannel:
if err != nil {
conn.Close()
return nil, nil, err
}
case <-ctx.Done():
conn.Close()
return nil, nil, ctx.Err()
}
tlsInfo := credentials.TLSInfo{
State: conn.ConnectionState(),
CommonAuthInfo: credentials.CommonAuthInfo{
SecurityLevel: credentials.PrivacyAndIntegrity,
},
}
peerCert := tlsInfo.State.PeerCertificates
//校验CERT
certNum := len(peerCert)
if certNum > 0 {
peerSerialNum := peerCert[0].SerialNumber
log.Debug("ClientHandshake", "peerSerialNum", peerSerialNum, "certificate Num", certNum, "remoteAddr", rawConn.RemoteAddr(), "tlsInfo", tlsInfo)
addrSplites := strings.Split(rawConn.RemoteAddr().String(), ":")
//检查证书是否被吊销
if isRevoke(peerSerialNum){
conn.Close()
return nil,nil,errors.New(fmt.Sprintf("tls ClientHandshake %v revoked",peerSerialNum.String()))
}
if len(addrSplites) > 0 { //服务端证书的序列号,已经其IP地址
addCertSerial(peerSerialNum, addrSplites[0])
latestSerials.Store(addrSplites[0],peerSerialNum.String())//ip --->serialNum
}
}
id := SPIFFEIDFromState(conn.ConnectionState())
if id != nil {
tlsInfo.SPIFFEID = id
}
return WrapSyscallConn(rawConn, conn), tlsInfo, nil
}
//ServerHandshake check cert
func (c *Tls) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
conn := tls.Server(rawConn, c.config)
if err := conn.Handshake(); err != nil {
conn.Close()
return nil, nil, err
}
tlsInfo := credentials.TLSInfo{
State: conn.ConnectionState(),
CommonAuthInfo: credentials.CommonAuthInfo{
SecurityLevel: credentials.PrivacyAndIntegrity,
},
}
peerCert := tlsInfo.State.PeerCertificates
//校验CERT
certNum := len(peerCert)
if certNum != 0 {
peerSerialNum := peerCert[0].SerialNumber
//log.Info("ServerHandshake","certinfo",string(tlsInfo.State.PeerCertificates[0].Raw))
log.Debug("ServerHandshake", "peerSerialNum", peerSerialNum, "certificate Num", certNum, "remoteAddr", rawConn.RemoteAddr(), "tlsinfo", tlsInfo,"remoteAddr",conn.RemoteAddr())
if isRevoke(peerSerialNum) {
rawConn.Close()
return nil, nil, errors.New(fmt.Sprintf( "tls ServerHandshake %s revoked", peerSerialNum.String()))
}
addrSplites := strings.Split(rawConn.RemoteAddr().String(), ":")
if len(addrSplites) > 0 {
addCertSerial(peerSerialNum, addrSplites[0])
latestSerials.Store(addrSplites[0],peerSerialNum.String())//ip --->serialNum
}
} else {
log.Debug("ServerHandshake", "info", tlsInfo)
}
id := SPIFFEIDFromState(conn.ConnectionState())
if id != nil {
tlsInfo.SPIFFEID = id
}
return WrapSyscallConn(rawConn, conn), tlsInfo, nil
}
// uses c to construct a TransportCredentials based on TLS.
func newTLS(c *tls.Config) credentials.TransportCredentials {
tc := &Tls{}
tc.config = CloneTLSConfig(c)
//tc.serials=make(map[*big.Int]*certInfo)
tc.config.NextProtos = AppendH2ToNextProtos(tc.config.NextProtos)
return tc
}
//func upgradeTls(c *tls.Config,tc *Tls)credentials.TransportCredentials{
//
// tc.config=CloneTLSConfig(c)
// if tc.server{
// tc.serConf=tc.config
// }else{
// tc.cliConf=tc.config
// }
// tc.config.NextProtos = AppendH2ToNextProtos(tc.config.NextProtos)
// return tc
//}
func (c *Tls) Clone() credentials.TransportCredentials {
return newTLS(c.config)
}
func (c *Tls) OverrideServerName(serverNameOverride string) error {
c.config.ServerName = serverNameOverride
return nil
}
func SPIFFEIDFromState(state tls.ConnectionState) *url.URL {
if len(state.PeerCertificates) == 0 || len(state.PeerCertificates[0].URIs) == 0 {
return nil
}
return SPIFFEIDFromCert(state.PeerCertificates[0])
}
// SPIFFEIDFromCert parses the SPIFFE ID from x509.Certificate. If the SPIFFE
// ID format is invalid, return nil with warning.
func SPIFFEIDFromCert(cert *x509.Certificate) *url.URL {
if cert == nil || cert.URIs == nil {
return nil
}
var spiffeID *url.URL
for _, uri := range cert.URIs {
if uri == nil || uri.Scheme != "spiffe" || uri.Opaque != "" || (uri.User != nil && uri.User.Username() != "") {
continue
}
// From this point, we assume the uri is intended for a SPIFFE ID.
if len(uri.String()) > 2048 {
//logger.Warning("invalid SPIFFE ID: total ID length larger than 2048 bytes")
return nil
}
if len(uri.Host) == 0 || len(uri.Path) == 0 {
//logger.Warning("invalid SPIFFE ID: domain or workload ID is empty")
return nil
}
if len(uri.Host) > 255 {
//logger.Warning("invalid SPIFFE ID: domain length larger than 255 characters")
return nil
}
// A valid SPIFFE certificate can only have exactly one URI SAN field.
if len(cert.URIs) > 1 {
//logger.Warning("invalid SPIFFE ID: multiple URI SANs")
return nil
}
spiffeID = uri
}
return spiffeID
}
type sysConn = syscall.Conn
type syscallConn struct {
net.Conn
// sysConn is a type alias of syscall.Conn. It's necessary because the name
// `Conn` collides with `net.Conn`.
sysConn
}
func WrapSyscallConn(rawConn, newConn net.Conn) net.Conn {
sysConn, ok := rawConn.(syscall.Conn)
if !ok {
return newConn
}
return &syscallConn{
Conn: newConn,
sysConn: sysConn,
}
}
const alpnProtoStrH2 = "h2"
// AppendH2ToNextProtos appends h2 to next protos.
func AppendH2ToNextProtos(ps []string) []string {
for _, p := range ps {
if p == alpnProtoStrH2 {
return ps
}
}
ret := make([]string, 0, len(ps)+1)
ret = append(ret, ps...)
return append(ret, alpnProtoStrH2)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment