Commit 2b61b2b1 authored by linj's avatar linj Committed by vipwzw

update chain33

parent 6252a9f1
package blockchain_test package blockchain
import ( import (
"testing" "testing"
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"github.com/33cn/chain33/blockchain"
dbm "github.com/33cn/chain33/common/db" dbm "github.com/33cn/chain33/common/db"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -20,7 +19,7 @@ func TestGetStoreUpgradeMeta(t *testing.T) { ...@@ -20,7 +19,7 @@ func TestGetStoreUpgradeMeta(t *testing.T) {
blockStoreDB := dbm.NewDB("blockchain", "leveldb", dir, 100) blockStoreDB := dbm.NewDB("blockchain", "leveldb", dir, 100)
blockStore := blockchain.NewBlockStore(nil, blockStoreDB, nil) blockStore := NewBlockStore(nil, blockStoreDB, nil)
require.NotNil(t, blockStore) require.NotNil(t, blockStore)
meta, err := blockStore.GetStoreUpgradeMeta() meta, err := blockStore.GetStoreUpgradeMeta()
...@@ -34,3 +33,105 @@ func TestGetStoreUpgradeMeta(t *testing.T) { ...@@ -34,3 +33,105 @@ func TestGetStoreUpgradeMeta(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, meta.Version, "1.0.0") require.Equal(t, meta.Version, "1.0.0")
} }
func TestSeqSaveAndGet(t *testing.T) {
dir, err := ioutil.TempDir("", "example")
assert.Nil(t, err)
defer os.RemoveAll(dir) // clean up
os.RemoveAll(dir) //删除已存在目录
blockStoreDB := dbm.NewDB("blockchain", "leveldb", dir, 100)
blockStore := NewBlockStore(nil, blockStoreDB, nil)
assert.NotNil(t, blockStore)
blockStore.saveSequence = true
blockStore.isParaChain = false
newBatch := blockStore.NewBatch(true)
seq, err := blockStore.saveBlockSequence(newBatch, []byte("s0"), 0, 1, 0)
assert.Nil(t, err)
assert.Equal(t, int64(0), seq)
err = newBatch.Write()
assert.Nil(t, err)
newBatch = blockStore.NewBatch(true)
seq, err = blockStore.saveBlockSequence(newBatch, []byte("s1"), 1, 1, 0)
assert.Nil(t, err)
assert.Equal(t, int64(1), seq)
err = newBatch.Write()
assert.Nil(t, err)
s, err := blockStore.LoadBlockLastSequence()
assert.Nil(t, err)
assert.Equal(t, int64(1), s)
s2, err := blockStore.GetBlockSequence(s)
assert.Nil(t, err)
assert.Equal(t, []byte("s1"), s2.Hash)
s3, err := blockStore.GetSequenceByHash([]byte("s1"))
assert.Nil(t, err)
assert.Equal(t, int64(1), s3)
}
func TestParaSeqSaveAndGet(t *testing.T) {
dir, err := ioutil.TempDir("", "example")
assert.Nil(t, err)
defer os.RemoveAll(dir) // clean up
os.RemoveAll(dir) //删除已存在目录
blockStoreDB := dbm.NewDB("blockchain", "leveldb", dir, 100)
blockStore := NewBlockStore(nil, blockStoreDB, nil)
assert.NotNil(t, blockStore)
blockStore.saveSequence = true
blockStore.isParaChain = true
newBatch := blockStore.NewBatch(true)
seq, err := blockStore.saveBlockSequence(newBatch, []byte("s0"), 0, 1, 1)
assert.Nil(t, err)
assert.Equal(t, int64(0), seq)
err = newBatch.Write()
assert.Nil(t, err)
newBatch = blockStore.NewBatch(true)
seq, err = blockStore.saveBlockSequence(newBatch, []byte("s1"), 1, 1, 10)
assert.Nil(t, err)
assert.Equal(t, int64(1), seq)
err = newBatch.Write()
assert.Nil(t, err)
s, err := blockStore.LoadBlockLastSequence()
assert.Nil(t, err)
assert.Equal(t, int64(1), s)
s2, err := blockStore.GetBlockSequence(s)
assert.Nil(t, err)
assert.Equal(t, []byte("s1"), s2.Hash)
s3, err := blockStore.GetSequenceByHash([]byte("s1"))
assert.Nil(t, err)
assert.Equal(t, int64(1), s3)
s4, err := blockStore.GetMainSequenceByHash([]byte("s1"))
assert.Nil(t, err)
assert.Equal(t, int64(10), s4)
s5, err := blockStore.LoadBlockLastMainSequence()
assert.Nil(t, err)
assert.Equal(t, int64(10), s5)
s6, err := blockStore.GetBlockByMainSequence(1)
assert.Nil(t, err)
assert.Equal(t, []byte("s0"), s6.Hash)
chain := &BlockChain{
blockStore: blockStore,
}
s7, err := chain.ProcGetMainSeqByHash([]byte("s0"))
assert.Nil(t, err)
assert.Equal(t, int64(1), s7)
_, err = chain.ProcGetMainSeqByHash([]byte("s0-not-exist"))
assert.NotNil(t, err)
}
...@@ -245,9 +245,8 @@ func (chain *BlockChain) GetOrphanPool() *OrphanPool { ...@@ -245,9 +245,8 @@ func (chain *BlockChain) GetOrphanPool() *OrphanPool {
//InitBlockChain 区块链初始化 //InitBlockChain 区块链初始化
func (chain *BlockChain) InitBlockChain() { func (chain *BlockChain) InitBlockChain() {
//isRecordBlockSequence配置的合法性检测 //isRecordBlockSequence配置的合法性检测
if !chain.cfg.IsParaChain { chain.blockStore.SequenceMustValid(chain.isRecordBlockSequence)
chain.blockStore.isRecordBlockSequenceValid(chain)
}
//先缓存最新的128个block信息到cache中 //先缓存最新的128个block信息到cache中
curheight := chain.GetBlockHeight() curheight := chain.GetBlockHeight()
if types.IsEnable("TxHeight") { if types.IsEnable("TxHeight") {
......
...@@ -19,6 +19,7 @@ import ( ...@@ -19,6 +19,7 @@ import (
"github.com/33cn/chain33/common/log" "github.com/33cn/chain33/common/log"
"github.com/33cn/chain33/common/log/log15" "github.com/33cn/chain33/common/log/log15"
"github.com/33cn/chain33/common/merkle" "github.com/33cn/chain33/common/merkle"
"github.com/33cn/chain33/queue"
_ "github.com/33cn/chain33/system" _ "github.com/33cn/chain33/system"
"github.com/33cn/chain33/types" "github.com/33cn/chain33/types"
"github.com/33cn/chain33/util" "github.com/33cn/chain33/util"
...@@ -129,6 +130,8 @@ func TestBlockChain(t *testing.T) { ...@@ -129,6 +130,8 @@ func TestBlockChain(t *testing.T) {
testReadBlockToExec(t, blockchain) testReadBlockToExec(t, blockchain)
testReExecBlock(t, blockchain) testReExecBlock(t, blockchain)
testUpgradeStore(t, blockchain) testUpgradeStore(t, blockchain)
testProcMainSeqMsg(t, blockchain)
} }
func testProcAddBlockMsg(t *testing.T, mock33 *testnode.Chain33Mock, blockchain *blockchain.BlockChain) { func testProcAddBlockMsg(t *testing.T, mock33 *testnode.Chain33Mock, blockchain *blockchain.BlockChain) {
...@@ -1124,3 +1127,17 @@ func testUpgradeStore(t *testing.T, chain *blockchain.BlockChain) { ...@@ -1124,3 +1127,17 @@ func testUpgradeStore(t *testing.T, chain *blockchain.BlockChain) {
chain.UpgradeStore() chain.UpgradeStore()
chainlog.Info("UpgradeStore end ---------------------") chainlog.Info("UpgradeStore end ---------------------")
} }
func testProcMainSeqMsg(t *testing.T, blockchain *blockchain.BlockChain) {
chainlog.Info("testProcMainSeqMsg begin -------------------")
msg := queue.NewMessage(1, "blockchain", types.EventGetLastBlockMainSequence, nil)
blockchain.GetLastBlockMainSequence(msg)
assert.Equal(t, int64(types.EventGetLastBlockMainSequence), msg.Ty)
msg = queue.NewMessage(1, "blockchain", types.EventGetMainSeqByHash, &types.ReqHash{Hash: []byte("hash")})
blockchain.GetMainSeqByHash(msg)
assert.Equal(t, int64(types.EventGetMainSeqByHash), msg.Ty)
chainlog.Info("testProcMainSeqMsg end --------------------")
}
...@@ -70,7 +70,6 @@ func (chain *BlockChain) ProcRecvMsg() { ...@@ -70,7 +70,6 @@ func (chain *BlockChain) ProcRecvMsg() {
case types.EventGetBlockByHashes: case types.EventGetBlockByHashes:
go chain.processMsg(msg, reqnum, chain.getBlockByHashes) go chain.processMsg(msg, reqnum, chain.getBlockByHashes)
case types.EventGetBlockBySeq: case types.EventGetBlockBySeq:
go chain.processMsg(msg, reqnum, chain.getBlockBySeq) go chain.processMsg(msg, reqnum, chain.getBlockBySeq)
...@@ -90,6 +89,12 @@ func (chain *BlockChain) ProcRecvMsg() { ...@@ -90,6 +89,12 @@ func (chain *BlockChain) ProcRecvMsg() {
case types.EventGetSeqCBLastNum: case types.EventGetSeqCBLastNum:
go chain.processMsg(msg, reqnum, chain.getSeqCBLastNum) go chain.processMsg(msg, reqnum, chain.getSeqCBLastNum)
case types.EventGetLastBlockMainSequence:
go chain.processMsg(msg, reqnum, chain.GetLastBlockMainSequence)
case types.EventGetMainSeqByHash:
go chain.processMsg(msg, reqnum, chain.GetMainSeqByHash)
default: default:
go chain.processMsg(msg, reqnum, chain.unknowMsg) go chain.processMsg(msg, reqnum, chain.unknowMsg)
} }
...@@ -519,3 +524,28 @@ func (chain *BlockChain) localAddrTxCount(msg *queue.Message) { ...@@ -519,3 +524,28 @@ func (chain *BlockChain) localAddrTxCount(msg *queue.Message) {
counts = count.Data counts = count.Data
msg.Reply(chain.client.NewMessage("rpc", types.EventLocalReplyValue, &types.Int64{Data: counts})) msg.Reply(chain.client.NewMessage("rpc", types.EventLocalReplyValue, &types.Int64{Data: counts}))
} }
//GetLastBlockMainSequence 获取最新的block执行序列号
func (chain *BlockChain) GetLastBlockMainSequence(msg *queue.Message) {
var lastSequence types.Int64
var err error
lastSequence.Data, err = chain.blockStore.LoadBlockLastMainSequence()
if err != nil {
chainlog.Debug("GetLastBlockMainSequence", "err", err)
msg.Reply(chain.client.NewMessage("rpc", types.EventReplyLastBlockMainSequence, err))
return
}
msg.Reply(chain.client.NewMessage("rpc", types.EventReplyLastBlockMainSequence, &lastSequence))
}
//GetMainSeqByHash parachian 通过blockhash获取对应的seq,只记录了addblock时的seq
func (chain *BlockChain) GetMainSeqByHash(msg *queue.Message) {
blockhash := (msg.Data).(*types.ReqHash)
seq, err := chain.ProcGetMainSeqByHash(blockhash.Hash)
if err != nil {
chainlog.Error("GetMainSeqByHash", "err", err.Error())
msg.Reply(chain.client.NewMessage("rpc", types.EventReplyMainSeqByHash, err))
return
}
msg.Reply(chain.client.NewMessage("rpc", types.EventReplyMainSeqByHash, &types.Int64{Data: seq}))
}
...@@ -402,7 +402,7 @@ func (b *BlockChain) connectBlock(node *blockNode, blockdetail *types.BlockDetai ...@@ -402,7 +402,7 @@ func (b *BlockChain) connectBlock(node *blockNode, blockdetail *types.BlockDetai
} }
} }
//目前非平行链并开启isRecordBlockSequence功能 //目前非平行链并开启isRecordBlockSequence功能
if b.isRecordBlockSequence && !b.isParaChain { if b.isRecordBlockSequence {
b.pushseq.updateSeq(lastSequence) b.pushseq.updateSeq(lastSequence)
} }
return blockdetail, nil return blockdetail, nil
...@@ -471,7 +471,7 @@ func (b *BlockChain) disconnectBlock(node *blockNode, blockdetail *types.BlockDe ...@@ -471,7 +471,7 @@ func (b *BlockChain) disconnectBlock(node *blockNode, blockdetail *types.BlockDe
chainlog.Debug("disconnectBlock success", "newtipnode.hash", common.ToHex(newtipnode.hash), "delblock.parent.hash", common.ToHex(blockdetail.Block.GetParentHash())) chainlog.Debug("disconnectBlock success", "newtipnode.hash", common.ToHex(newtipnode.hash), "delblock.parent.hash", common.ToHex(blockdetail.Block.GetParentHash()))
//目前非平行链并开启isRecordBlockSequence功能 //目前非平行链并开启isRecordBlockSequence功能
if b.isRecordBlockSequence && !b.isParaChain { if b.isRecordBlockSequence {
b.pushseq.updateSeq(lastSequence) b.pushseq.updateSeq(lastSequence)
} }
return nil return nil
......
...@@ -90,6 +90,18 @@ func (chain *BlockChain) ProcGetSeqByHash(hash []byte) (int64, error) { ...@@ -90,6 +90,18 @@ func (chain *BlockChain) ProcGetSeqByHash(hash []byte) (int64, error) {
return seq, err return seq, err
} }
//ProcGetMainSeqByHash 处理共识过来的通过blockhash获取seq的消息,只提供add block时的seq,用于平行链block回退
func (chain *BlockChain) ProcGetMainSeqByHash(hash []byte) (int64, error) {
if len(hash) == 0 {
chainlog.Error("ProcGetMainSeqByHash input hash is null")
return -1, types.ErrInvalidParam
}
seq, err := chain.blockStore.GetMainSequenceByHash(hash)
chainlog.Debug("ProcGetMainSeqByHash", "blockhash", common.ToHex(hash), "seq", seq, "err", err)
return seq, err
}
//ProcAddBlockSeqCB 添加seq callback //ProcAddBlockSeqCB 添加seq callback
func (chain *BlockChain) ProcAddBlockSeqCB(cb *types.BlockSeqCB) error { func (chain *BlockChain) ProcAddBlockSeqCB(cb *types.BlockSeqCB) error {
if cb == nil { if cb == nil {
......
...@@ -162,6 +162,18 @@ func (m *mockBlockChain) SetQueueClient(q queue.Queue) { ...@@ -162,6 +162,18 @@ func (m *mockBlockChain) SetQueueClient(q queue.Queue) {
} else { } else {
msg.ReplyErr("transaction id must 9999", types.ErrInvalidParam) msg.ReplyErr("transaction id must 9999", types.ErrInvalidParam)
} }
case types.EventGetMainSeqByHash:
if req, ok := msg.GetData().(*types.ReqHash); ok && string(req.Hash) == "exist-hash" {
msg.Reply(client.NewMessage(blockchainKey, types.EventReplyMainSeqByHash, &types.Int64{Data: 9999}))
} else {
msg.ReplyErr("transaction hash is not exist-hash", types.ErrInvalidParam)
}
case types.EventGetLastBlockMainSequence:
if _, ok := msg.GetData().(*types.ReqNil); ok {
msg.Reply(client.NewMessage(blockchainKey, types.EventReplyLastBlockMainSequence, &types.Int64{Data: 9999}))
} else {
msg.ReplyErr("request must be nil", types.ErrInvalidParam)
}
default: default:
msg.ReplyErr("Do not support", types.ErrNotSupport) msg.ReplyErr("Do not support", types.ErrNotSupport)
} }
......
...@@ -361,6 +361,29 @@ func (_m *QueueProtocolAPI) GetHeaders(param *types.ReqBlocks) (*types.Headers, ...@@ -361,6 +361,29 @@ func (_m *QueueProtocolAPI) GetHeaders(param *types.ReqBlocks) (*types.Headers,
return r0, r1 return r0, r1
} }
// GetLastBlockMainSequence provides a mock function with given fields:
func (_m *QueueProtocolAPI) GetLastBlockMainSequence() (*types.Int64, error) {
ret := _m.Called()
var r0 *types.Int64
if rf, ok := ret.Get(0).(func() *types.Int64); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*types.Int64)
}
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetLastBlockSequence provides a mock function with given fields: // GetLastBlockSequence provides a mock function with given fields:
func (_m *QueueProtocolAPI) GetLastBlockSequence() (*types.Int64, error) { func (_m *QueueProtocolAPI) GetLastBlockSequence() (*types.Int64, error) {
ret := _m.Called() ret := _m.Called()
...@@ -430,6 +453,29 @@ func (_m *QueueProtocolAPI) GetLastMempool() (*types.ReplyTxList, error) { ...@@ -430,6 +453,29 @@ func (_m *QueueProtocolAPI) GetLastMempool() (*types.ReplyTxList, error) {
return r0, r1 return r0, r1
} }
// GetMainSequenceByHash provides a mock function with given fields: param
func (_m *QueueProtocolAPI) GetMainSequenceByHash(param *types.ReqHash) (*types.Int64, error) {
ret := _m.Called(param)
var r0 *types.Int64
if rf, ok := ret.Get(0).(func(*types.ReqHash) *types.Int64); ok {
r0 = rf(param)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*types.Int64)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(*types.ReqHash) error); ok {
r1 = rf(param)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetMempool provides a mock function with given fields: // GetMempool provides a mock function with given fields:
func (_m *QueueProtocolAPI) GetMempool() (*types.ReplyTxList, error) { func (_m *QueueProtocolAPI) GetMempool() (*types.ReplyTxList, error) {
ret := _m.Called() ret := _m.Called()
......
...@@ -1139,3 +1139,36 @@ func (q *QueueProtocol) GetSeqCallBackLastNum(param *types.ReqString) (*types.In ...@@ -1139,3 +1139,36 @@ func (q *QueueProtocol) GetSeqCallBackLastNum(param *types.ReqString) (*types.In
} }
return nil, types.ErrTypeAsset return nil, types.ErrTypeAsset
} }
// GetLastBlockMainSequence 获取最新的block执行序列号
func (q *QueueProtocol) GetLastBlockMainSequence() (*types.Int64, error) {
msg, err := q.query(blockchainKey, types.EventGetLastBlockMainSequence, &types.ReqNil{})
if err != nil {
log.Error("GetLastBlockMainSequence", "Error", err.Error())
return nil, err
}
if reply, ok := msg.GetData().(*types.Int64); ok {
return reply, nil
}
return nil, types.ErrTypeAsset
}
// GetMainSequenceByHash 通过hash获取对应的执行序列号
func (q *QueueProtocol) GetMainSequenceByHash(param *types.ReqHash) (*types.Int64, error) {
if param == nil {
err := types.ErrInvalidParam
log.Error("GetMainSequenceByHash", "Error", err)
return nil, err
}
msg, err := q.query(blockchainKey, types.EventGetMainSeqByHash, param)
if err != nil {
log.Error("GetMainSequenceByHash", "Error", err.Error())
return nil, err
}
if reply, ok := msg.GetData().(*types.Int64); ok {
return reply, nil
}
return nil, types.ErrTypeAsset
}
...@@ -1196,3 +1196,26 @@ func TestGetBlockBySeq(t *testing.T) { ...@@ -1196,3 +1196,26 @@ func TestGetBlockBySeq(t *testing.T) {
assert.NotNil(t, err) assert.NotNil(t, err)
} }
func TestGetMainSeq(t *testing.T) {
net := queue.New("test-seq-api")
defer net.Close()
chain := &mockBlockChain{}
chain.SetQueueClient(net)
defer chain.Close()
api, err := client.New(net.Client(), nil)
assert.Nil(t, err)
seq, err := api.GetMainSequenceByHash(&types.ReqHash{Hash: []byte("exist-hash")})
assert.Nil(t, err)
assert.Equal(t, int64(9999), seq.Data)
seq, err = api.GetMainSequenceByHash(&types.ReqHash{Hash: []byte("")})
assert.NotNil(t, err)
seq1, err := api.GetLastBlockMainSequence()
assert.Nil(t, err)
assert.Equal(t, int64(9999), seq1.Data)
}
...@@ -128,6 +128,12 @@ type QueueProtocolAPI interface { ...@@ -128,6 +128,12 @@ type QueueProtocolAPI interface {
//types.EventGetSequenceByHash: //types.EventGetSequenceByHash:
GetSequenceByHash(param *types.ReqHash) (*types.Int64, error) GetSequenceByHash(param *types.ReqHash) (*types.Int64, error)
// 在平行链上获得主链Sequence相关的接口
//types.EventGetLastBlockSequence:
GetLastBlockMainSequence() (*types.Int64, error)
//types.EventGetSequenceByHash:
GetMainSequenceByHash(param *types.ReqHash) (*types.Int64, error)
// --------------- blockchain interfaces end // --------------- blockchain interfaces end
// +++++++++++++++ store interfaces begin // +++++++++++++++ store interfaces begin
......
...@@ -59,7 +59,7 @@ func NewMempool(cfg *types.Mempool) *Mempool { ...@@ -59,7 +59,7 @@ func NewMempool(cfg *types.Mempool) *Mempool {
pool.cfg = cfg pool.cfg = cfg
pool.poolHeader = make(chan struct{}, 2) pool.poolHeader = make(chan struct{}, 2)
pool.removeBlockTicket = time.NewTicker(time.Minute) pool.removeBlockTicket = time.NewTicker(time.Minute)
pool.cache = newCache(cfg.MaxTxNumPerAccount, cfg.MaxTxLast) pool.cache = newCache(cfg.MaxTxNumPerAccount, cfg.MaxTxLast, cfg.PoolCacheSize)
return pool return pool
} }
...@@ -421,3 +421,26 @@ func (mem *Mempool) setSync(status bool) { ...@@ -421,3 +421,26 @@ func (mem *Mempool) setSync(status bool) {
mem.sync = status mem.sync = status
mem.proxyMtx.Unlock() mem.proxyMtx.Unlock()
} }
// getTxListByHash 从qcache或者SHashTxCache中获取hash对应的tx交易列表
func (mem *Mempool) getTxListByHash(hashList *types.ReqTxHashList) *types.ReplyTxList {
mem.proxyMtx.Lock()
defer mem.proxyMtx.Unlock()
var replyTxList types.ReplyTxList
//通过短hash来获取tx交易
if hashList.GetIsShortHash() {
for _, sHash := range hashList.GetHashes() {
tx := mem.cache.GetSHashTxCache(sHash)
replyTxList.Txs = append(replyTxList.Txs, tx)
}
return &replyTxList
}
//通过hash来获取tx交易
for _, hash := range hashList.GetHashes() {
tx := mem.cache.getTxByHash(hash)
replyTxList.Txs = append(replyTxList.Txs, tx)
}
return &replyTxList
}
...@@ -34,13 +34,15 @@ type txCache struct { ...@@ -34,13 +34,15 @@ type txCache struct {
qcache QueueCache qcache QueueCache
totalFee int64 totalFee int64
totalByte int64 totalByte int64
*SHashTxCache
} }
//NewTxCache init accountIndex and last cache //NewTxCache init accountIndex and last cache
func newCache(maxTxPerAccount int64, sizeLast int64) *txCache { func newCache(maxTxPerAccount int64, sizeLast int64, poolCacheSize int64) *txCache {
return &txCache{ return &txCache{
AccountTxIndex: NewAccountTxIndex(int(maxTxPerAccount)), AccountTxIndex: NewAccountTxIndex(int(maxTxPerAccount)),
LastTxCache: NewLastTxCache(int(sizeLast)), LastTxCache: NewLastTxCache(int(sizeLast)),
SHashTxCache: NewSHashTxCache(int(poolCacheSize)),
} }
} }
...@@ -64,6 +66,7 @@ func (cache *txCache) Remove(hash string) { ...@@ -64,6 +66,7 @@ func (cache *txCache) Remove(hash string) {
cache.LastTxCache.Remove(tx) cache.LastTxCache.Remove(tx)
cache.totalFee -= tx.Fee cache.totalFee -= tx.Fee
cache.totalByte -= int64(proto.Size(tx)) cache.totalByte -= int64(proto.Size(tx))
cache.SHashTxCache.Remove(tx)
} }
//Exist 是否存在 //Exist 是否存在
...@@ -132,6 +135,7 @@ func (cache *txCache) Push(tx *types.Transaction) error { ...@@ -132,6 +135,7 @@ func (cache *txCache) Push(tx *types.Transaction) error {
cache.LastTxCache.Push(tx) cache.LastTxCache.Push(tx)
cache.totalFee += tx.Fee cache.totalFee += tx.Fee
cache.totalByte += int64(proto.Size(tx)) cache.totalByte += int64(proto.Size(tx))
cache.SHashTxCache.Push(tx)
return nil return nil
} }
...@@ -156,3 +160,12 @@ func isExpired(item *Item, height, blockTime int64) bool { ...@@ -156,3 +160,12 @@ func isExpired(item *Item, height, blockTime int64) bool {
} }
return false return false
} }
//getTxByHash 通过交易hash获取tx交易信息
func (cache *txCache) getTxByHash(hash string) *types.Transaction {
item, err := cache.qcache.GetItem(hash)
if err != nil {
return nil
}
return item.Value
}
...@@ -90,6 +90,9 @@ func (mem *Mempool) eventProcess() { ...@@ -90,6 +90,9 @@ func (mem *Mempool) eventProcess() {
case types.EventGetProperFee: case types.EventGetProperFee:
// 获取对应排队策略中合适的手续费 // 获取对应排队策略中合适的手续费
mem.eventGetProperFee(msg) mem.eventGetProperFee(msg)
// 消息类型EventTxListByHash:通过hash获取对应的tx列表
case types.EventTxListByHash:
mem.eventTxListByHash(msg)
default: default:
} }
mlog.Debug("mempool", "cost", types.Since(beg), "msg", types.GetEventName(int(msg.Ty))) mlog.Debug("mempool", "cost", types.Since(beg), "msg", types.GetEventName(int(msg.Ty)))
...@@ -205,3 +208,10 @@ func (mem *Mempool) checkSign(data *queue.Message) *queue.Message { ...@@ -205,3 +208,10 @@ func (mem *Mempool) checkSign(data *queue.Message) *queue.Message {
data.Data = types.ErrSign data.Data = types.ErrSign
return data return data
} }
// eventTxListByHash 通过hash获取tx列表
func (mem *Mempool) eventTxListByHash(msg *queue.Message) {
shashList := msg.GetData().(*types.ReqTxHashList)
replytxList := mem.getTxListByHash(shashList)
msg.Reply(mem.client.NewMessage("", types.EventReplyTxList, replytxList))
}
...@@ -1144,3 +1144,94 @@ func execProcess(q queue.Queue) { ...@@ -1144,3 +1144,94 @@ func execProcess(q queue.Queue) {
} }
}() }()
} }
func TestTx(t *testing.T) {
subConfig := SubConfig{10240, 10000}
cache := newCache(10240, 10, 10240)
cache.SetQueueCache(NewSimpleQueue(subConfig))
tx := &types.Transaction{Execer: []byte("user.write"), Payload: types.Encode(transfer), Fee: 100000000, Expire: 0, To: toAddr}
var replyTxList types.ReplyTxList
var sHastList types.ReqTxHashList
var hastList types.ReqTxHashList
for i := 1; i <= 10240; i++ {
tx.Expire = int64(i)
cache.Push(tx)
sHastList.Hashes = append(sHastList.Hashes, types.CalcTxShortHash(tx.Hash()))
hastList.Hashes = append(hastList.Hashes, string(tx.Hash()))
}
for i := 1; i <= 1600; i++ {
Tx := cache.GetSHashTxCache(sHastList.Hashes[i])
if Tx == nil {
panic("TestTx:GetSHashTxCache is nil")
}
replyTxList.Txs = append(replyTxList.Txs, Tx)
}
for i := 1; i <= 1600; i++ {
Tx := cache.getTxByHash(hastList.Hashes[i])
if Tx == nil {
panic("TestTx:getTxByHash is nil")
}
replyTxList.Txs = append(replyTxList.Txs, Tx)
}
}
func TestEventTxListByHash(t *testing.T) {
q, mem := initEnv(0)
defer q.Close()
defer mem.Close()
// add tx
hashes, err := add4TxHash(mem.client)
if err != nil {
t.Error("add tx error", err.Error())
return
}
//通过交易hash获取交易信息
reqTxHashList := types.ReqTxHashList{
Hashes: hashes,
IsShortHash: false,
}
msg1 := mem.client.NewMessage("mempool", types.EventTxListByHash, &reqTxHashList)
mem.client.Send(msg1, true)
data1, err := mem.client.Wait(msg1)
if err != nil {
t.Error(err)
return
}
txs1 := data1.GetData().(*types.ReplyTxList).GetTxs()
if len(txs1) != 4 {
t.Error("TestEventTxListByHash:get txlist number error")
}
for i, tx := range txs1 {
if hashes[i] != string(tx.Hash()) {
t.Error("TestEventTxListByHash:hash mismatch")
}
}
//通过短hash获取tx交易
var shashes []string
for _, hash := range hashes {
shashes = append(shashes, types.CalcTxShortHash([]byte(hash)))
}
reqTxHashList.Hashes = shashes
reqTxHashList.IsShortHash = true
msg2 := mem.client.NewMessage("mempool", types.EventTxListByHash, &reqTxHashList)
mem.client.Send(msg2, true)
data2, err := mem.client.Wait(msg2)
if err != nil {
t.Error(err)
return
}
txs2 := data2.GetData().(*types.ReplyTxList).GetTxs()
for i, tx := range txs2 {
if hashes[i] != string(tx.Hash()) {
t.Error("TestEventTxListByHash:shash mismatch")
}
}
}
package mempool
import (
"github.com/33cn/chain33/common"
"github.com/33cn/chain33/common/listmap"
log "github.com/33cn/chain33/common/log/log15"
"github.com/33cn/chain33/types"
)
var shashlog = log.New("module", "mempool.shash")
//SHashTxCache 通过shorthash缓存交易
type SHashTxCache struct {
max int
l *listmap.ListMap
}
//NewSHashTxCache 创建通过短hash交易的cache
func NewSHashTxCache(size int) *SHashTxCache {
return &SHashTxCache{
max: size,
l: listmap.New(),
}
}
//GetSHashTxCache 返回shorthash对应的tx交易信息
func (cache *SHashTxCache) GetSHashTxCache(sHash string) *types.Transaction {
tx, err := cache.l.GetItem(sHash)
if err != nil {
return nil
}
return tx.(*types.Transaction)
}
//Remove remove tx of SHashTxCache
func (cache *SHashTxCache) Remove(tx *types.Transaction) {
txhash := tx.Hash()
cache.l.Remove(types.CalcTxShortHash(txhash))
//shashlog.Debug("SHashTxCache:Remove", "shash", types.CalcTxShortHash(txhash), "txhash", common.ToHex(txhash))
}
//Push tx into SHashTxCache
func (cache *SHashTxCache) Push(tx *types.Transaction) {
shash := types.CalcTxShortHash(tx.Hash())
if cache.Exist(shash) {
shashlog.Error("SHashTxCache:Push:Exist", "oldhash", common.ToHex(cache.GetSHashTxCache(shash).Hash()), "newhash", common.ToHex(tx.Hash()))
return
}
if cache.l.Size() >= cache.max {
shashlog.Error("SHashTxCache:Push:ErrMemFull", "cache.l.Size()", cache.l.Size(), "cache.max", cache.max)
return
}
cache.l.Push(shash, tx)
//shashlog.Debug("SHashTxCache:Push", "shash", shash, "txhash", common.ToHex(tx.Hash()))
}
//Exist 是否存在
func (cache *SHashTxCache) Exist(shash string) bool {
return cache.l.Exist(shash)
}
...@@ -152,10 +152,16 @@ const ( ...@@ -152,10 +152,16 @@ const (
EventReplyProperFee = 141 EventReplyProperFee = 141
EventReExecBlock = 142 EventReExecBlock = 142
EventTxListByHash = 143
//exec //exec
EventBlockChainQuery = 212 EventBlockChainQuery = 212
EventConsensusQuery = 213 EventConsensusQuery = 213
// BlockChain 接收的事件
EventGetLastBlockMainSequence = 300
EventReplyLastBlockMainSequence = 301
EventGetMainSeqByHash = 302
EventReplyMainSeqByHash = 303
) )
var eventName = map[int]string{ var eventName = map[int]string{
...@@ -301,4 +307,10 @@ var eventName = map[int]string{ ...@@ -301,4 +307,10 @@ var eventName = map[int]string{
//mempool //mempool
EventGetProperFee: "EventGetProperFee", EventGetProperFee: "EventGetProperFee",
EventReplyProperFee: "EventReplyProperFee", EventReplyProperFee: "EventReplyProperFee",
EventTxListByHash: "EventTxListByHash",
// block chain
EventGetLastBlockMainSequence: "EventGetLastBlockMainSequence",
EventReplyLastBlockMainSequence: "EventReplyLastBlockMainSequence",
EventGetMainSeqByHash: "EventGetMainSeqByHash",
EventReplyMainSeqByHash: "EventReplyMainSeqByHash",
} }
...@@ -242,3 +242,9 @@ message UpgradeMeta { ...@@ -242,3 +242,9 @@ message UpgradeMeta {
string version = 2; string version = 2;
int64 height = 3; int64 height = 3;
} }
//通过交易hash获取交易列表,需要区分是短hash还是全hash值
message ReqTxHashList {
repeated string hashes = 1;
bool isShortHash = 2;
}
...@@ -778,3 +778,11 @@ func ParseExpire(expire string) (int64, error) { ...@@ -778,3 +778,11 @@ func ParseExpire(expire string) (int64, error) {
return 0, err return 0, err
} }
//CalcTxShortHash 取txhash的前指定字节,目前默认5
func CalcTxShortHash(hash []byte) string {
if len(hash) >= 5 {
return hex.EncodeToString(hash[0:5])
}
return ""
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment