Commit e892e9e6 authored by 陈德海's avatar 陈德海

convert skiplist to chain33

parent 86ea79a1
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/gob" "encoding/gob"
"github.com/33cn/chain33/common/skiplist"
"github.com/33cn/chain33/system/mempool" "github.com/33cn/chain33/system/mempool"
"github.com/33cn/chain33/types" "github.com/33cn/chain33/types"
) )
...@@ -12,21 +13,21 @@ var mempoolDupResendInterval int64 = 600 // mempool内交易过期时间,10分 ...@@ -12,21 +13,21 @@ var mempoolDupResendInterval int64 = 600 // mempool内交易过期时间,10分
// Queue 价格队列模式(价格=手续费/交易字节数,价格高者优先,同价则时间早优先) // Queue 价格队列模式(价格=手续费/交易字节数,价格高者优先,同价则时间早优先)
type Queue struct { type Queue struct {
txMap map[string]*SkipValue txMap map[string]*skiplist.SkipValue
txList *SkipList txList *skiplist.SkipList
subConfig subConfig subConfig subConfig
} }
// NewQueue 创建队列 // NewQueue 创建队列
func NewQueue(subcfg subConfig) *Queue { func NewQueue(subcfg subConfig) *Queue {
return &Queue{ return &Queue{
txMap: make(map[string]*SkipValue, subcfg.PoolCacheSize), txMap: make(map[string]*skiplist.SkipValue, subcfg.PoolCacheSize),
txList: NewSkipList(&SkipValue{-1, nil}), txList: skiplist.NewSkipList(&skiplist.SkipValue{-1, nil}),
subConfig: subcfg, subConfig: subcfg,
} }
} }
func (cache *Queue) newSkipValue(item *mempool.Item) (*SkipValue, error) { func (cache *Queue) SkipValue(item *mempool.Item) (*skiplist.SkipValue, error) {
//tx := item.value //tx := item.value
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf) enc := gob.NewEncoder(buf)
...@@ -35,7 +36,7 @@ func (cache *Queue) newSkipValue(item *mempool.Item) (*SkipValue, error) { ...@@ -35,7 +36,7 @@ func (cache *Queue) newSkipValue(item *mempool.Item) (*SkipValue, error) {
return nil, err return nil, err
} }
size := len(buf.Bytes()) size := len(buf.Bytes())
return &SkipValue{Price: item.Value.Fee / int64(size), Value: item}, nil return &skiplist.SkipValue{Score: item.Value.Fee / int64(size), Value: item}, nil
} }
//Exist 是否存在 //Exist 是否存在
...@@ -67,7 +68,7 @@ func (cache *Queue) Push(item *mempool.Item) error { ...@@ -67,7 +68,7 @@ func (cache *Queue) Push(item *mempool.Item) error {
newEnterTime := types.Now().Unix() newEnterTime := types.Now().Unix()
resendItem := &mempool.Item{Value: item.Value, Priority: item.Value.Fee, EnterTime: newEnterTime} resendItem := &mempool.Item{Value: item.Value, Priority: item.Value.Fee, EnterTime: newEnterTime}
var err error var err error
sv, err := cache.newSkipValue(resendItem) sv, err := cache.SkipValue(resendItem)
if err != nil { if err != nil {
return err return err
} }
...@@ -79,16 +80,25 @@ func (cache *Queue) Push(item *mempool.Item) error { ...@@ -79,16 +80,25 @@ func (cache *Queue) Push(item *mempool.Item) error {
} }
it := &mempool.Item{Value: item.Value, Priority: item.Value.Fee, EnterTime: item.EnterTime} it := &mempool.Item{Value: item.Value, Priority: item.Value.Fee, EnterTime: item.EnterTime}
sv, err := cache.newSkipValue(it) sv, err := cache.SkipValue(it)
if err != nil { if err != nil {
return err return err
} }
if int64(cache.txList.Len()) >= cache.subConfig.PoolCacheSize { if int64(cache.txList.Len()) >= cache.subConfig.PoolCacheSize {
tail := cache.txList.GetIterator().Last() tail := cache.txList.GetIterator().Last()
//价格高存留 //价格高存留
if sv.Compare(tail) == -1 { switch sv.Compare(tail) {
case -1:
cache.Remove(string(tail.Value.(*mempool.Item).Value.Hash())) cache.Remove(string(tail.Value.(*mempool.Item).Value.Hash()))
} else { case 0:
if sv.Value.(*mempool.Item).EnterTime < tail.Value.(*mempool.Item).EnterTime {
cache.Remove(string(tail.Value.(*mempool.Item).Value.Hash()))
break
}
return types.ErrMemFull
case 1:
return types.ErrMemFull
default:
return types.ErrMemFull return types.ErrMemFull
} }
} }
......
...@@ -134,5 +134,5 @@ func TestQueueDirection(t *testing.T) { ...@@ -134,5 +134,5 @@ func TestQueueDirection(t *testing.T) {
cache.Push(item4) cache.Push(item4)
cache.Push(item5) cache.Push(item5)
cache.txList.Print() cache.txList.Print()
assert.Equal(t, true, cache.txList.GetIterator().First().Price >= cache.txList.GetIterator().Last().Price) assert.Equal(t, true, cache.txList.GetIterator().First().Score >= cache.txList.GetIterator().Last().Score)
} }
package price
import (
"fmt"
"math/rand"
"github.com/33cn/chain33/system/mempool"
)
const maxLevel = 32
const prob = 0.35
// SkipValue 跳跃表节点
type SkipValue struct {
Price int64
Value interface{}
}
// Compare 比较函数
func (v *SkipValue) Compare(value *SkipValue) int {
if v.Price > value.Price {
return -1
} else if v.Price == value.Price {
if v.Value.(*mempool.Item).EnterTime < value.Value.(*mempool.Item).EnterTime {
return -1
}
return 0
}
return 1
}
type skipListNode struct {
next []*skipListNode
prev *skipListNode
Value *SkipValue
}
// SkipList 跳跃表
type SkipList struct {
header, tail *skipListNode
findcount int
count int
level int
}
// SkipListIterator 跳跃表迭代器
type SkipListIterator struct {
list *SkipList
node *skipListNode
}
// First 获取第一个节点
func (sli *SkipListIterator) First() *SkipValue {
if sli.list.header.next[0] == nil {
return nil
}
sli.node = sli.list.header.next[0]
return sli.node.Value
}
// Last 获取最后一个节点
func (sli *SkipListIterator) Last() *SkipValue {
if sli.list.tail == nil {
return nil
}
sli.node = sli.list.tail
return sli.node.Value
}
// Prev 获取上一个节点
func (node *skipListNode) Prev() *skipListNode {
if node == nil || node.prev == nil {
return nil
}
return node.prev
}
// Next 获取下一个节点
func (node *skipListNode) Next() *skipListNode {
if node == nil || node.next[0] == nil {
return nil
}
return node.next[0]
}
func newskipListNode(level int, value *SkipValue) *skipListNode {
node := &skipListNode{}
node.next = make([]*skipListNode, level)
node.Value = value
return node
}
//NewSkipList 构建一个value的最小值
func NewSkipList(min *SkipValue) *SkipList {
sl := &SkipList{}
sl.level = 1
sl.header = newskipListNode(maxLevel, min)
return sl
}
func randomLevel() int {
level := 1
t := prob * 0xFFFF
for rand.Int()&0xFFFF < int(t) {
level++
if level == maxLevel {
break
}
}
return level
}
// GetIterator 返回第一个节点
func (sl *SkipList) GetIterator() *SkipListIterator {
it := &SkipListIterator{}
it.list = sl
it.First()
return it
}
// Len 返回节点数
func (sl *SkipList) Len() int {
return sl.count
}
func (sl *SkipList) find(value *SkipValue) *skipListNode {
x := sl.header
for i := sl.level - 1; i >= 0; i-- {
for x.next[i] != nil && x.next[i].Value.Compare(value) < 0 {
sl.findcount++
x = x.next[i]
}
}
return x
}
// FindCount 返回查询次数
func (sl *SkipList) FindCount() int {
return sl.findcount
}
// Find 查找skipvalue
func (sl *SkipList) Find(value *SkipValue) *SkipValue {
x := sl.find(value)
if x.next[0] != nil && x.next[0].Value.Compare(value) == 0 {
return x.next[0].Value
}
return nil
}
// Insert 插入节点
func (sl *SkipList) Insert(value *SkipValue) int {
var update [maxLevel]*skipListNode
x := sl.header
for i := sl.level - 1; i >= 0; i-- {
for x.next[i] != nil && x.next[i].Value.Compare(value) <= 0 {
x = x.next[i]
}
update[i] = x
}
//if x.next[0] != nil && x.next[0].Value.Compare(value) == 0 { //update
// x.next[0].Value = value
// return 0
//}
level := randomLevel()
if level > sl.level {
for i := sl.level; i < level; i++ {
update[i] = sl.header
}
sl.level = level
}
x = newskipListNode(level, value)
for i := 0; i < level; i++ {
x.next[i] = update[i].next[i]
update[i].next[i] = x
}
//形成一个双向链表
if update[0] != sl.header {
x.prev = update[0]
}
if x.next[0] != nil {
x.next[0].prev = x
} else {
sl.tail = x
}
sl.count++
return 1
}
// Delete 删除节点
func (sl *SkipList) Delete(value *SkipValue) int {
var update [maxLevel]*skipListNode
x := sl.header
for i := sl.level - 1; i >= 0; i-- {
for x.next[i] != nil && x.next[i].Value.Compare(value) < 0 {
x = x.next[i]
}
update[i] = x
}
if x.next[0] == nil || x.next[0].Value.Compare(value) != 0 { //not find
return 0
}
x = x.next[0]
for i := 0; i < sl.level; i++ {
if update[i].next[i] == x {
update[i].next[i] = x.next[i]
}
}
if x.next[0] != nil {
x.next[0].prev = x.prev
} else {
sl.tail = x.prev
}
for sl.level > 1 && sl.header.next[sl.level-1] == nil {
sl.level--
}
sl.count--
return 1
}
// Print 测试用的输出函数
func (sl *SkipList) Print() {
if sl.count > 0 {
r := sl.header
for i := sl.level - 1; i >= 0; i-- {
e := r.next[i]
//fmt.Print(i)
for e != nil {
fmt.Print(e.Value.Price)
fmt.Print(" ")
fmt.Print(e.Value)
fmt.Println("")
e = e.next[i]
}
fmt.Println()
}
} else {
fmt.Println("空")
}
}
//Walk 遍历整个结构,如果cb 返回false 那么停止遍历
func (sl *SkipList) Walk(cb func(value interface{}) bool) {
for e := sl.header.Next(); e != nil; e = e.Next() {
if cb == nil {
return
}
if !cb(e.Value.Value) {
return
}
}
}
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/gob" "encoding/gob"
"github.com/33cn/chain33/common/skiplist"
"github.com/33cn/chain33/system/mempool" "github.com/33cn/chain33/system/mempool"
"github.com/33cn/chain33/types" "github.com/33cn/chain33/types"
) )
...@@ -12,21 +13,21 @@ var mempoolDupResendInterval int64 = 600 // mempool内交易过期时间,10分 ...@@ -12,21 +13,21 @@ var mempoolDupResendInterval int64 = 600 // mempool内交易过期时间,10分
// Queue 分数队列模式(分数=常量a*手续费/交易字节数-常量b*时间*定量c,按分数排队,高的优先,常量a,b和定量c可配置) // Queue 分数队列模式(分数=常量a*手续费/交易字节数-常量b*时间*定量c,按分数排队,高的优先,常量a,b和定量c可配置)
type Queue struct { type Queue struct {
txMap map[string]*SkipValue txMap map[string]*skiplist.SkipValue
txList *SkipList txList *skiplist.SkipList
subConfig subConfig subConfig subConfig
} }
// NewQueue 创建队列 // NewQueue 创建队列
func NewQueue(subcfg subConfig) *Queue { func NewQueue(subcfg subConfig) *Queue {
return &Queue{ return &Queue{
txMap: make(map[string]*SkipValue, subcfg.PoolCacheSize), txMap: make(map[string]*skiplist.SkipValue, subcfg.PoolCacheSize),
txList: NewSkipList(&SkipValue{-1, nil}), txList: skiplist.NewSkipList(&skiplist.SkipValue{-1, nil}),
subConfig: subcfg, subConfig: subcfg,
} }
} }
func (cache *Queue) newSkipValue(item *mempool.Item) (*SkipValue, error) { func (cache *Queue) SkipValue(item *mempool.Item) (*skiplist.SkipValue, error) {
//tx := item.value //tx := item.value
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf) enc := gob.NewEncoder(buf)
...@@ -35,7 +36,7 @@ func (cache *Queue) newSkipValue(item *mempool.Item) (*SkipValue, error) { ...@@ -35,7 +36,7 @@ func (cache *Queue) newSkipValue(item *mempool.Item) (*SkipValue, error) {
return nil, err return nil, err
} }
size := len(buf.Bytes()) size := len(buf.Bytes())
return &SkipValue{Score: cache.subConfig.PriceConstant*(item.Value.Fee/int64(size))*cache.subConfig.PricePower - cache.subConfig.TimeParam*item.EnterTime, Value: item}, nil return &skiplist.SkipValue{Score: cache.subConfig.PriceConstant*(item.Value.Fee/int64(size))*cache.subConfig.PricePower - cache.subConfig.TimeParam*item.EnterTime, Value: item}, nil
} }
// Exist 是否存在 // Exist 是否存在
...@@ -68,7 +69,7 @@ func (cache *Queue) Push(item *mempool.Item) error { ...@@ -68,7 +69,7 @@ func (cache *Queue) Push(item *mempool.Item) error {
newEnterTime := types.Now().Unix() newEnterTime := types.Now().Unix()
resendItem := &mempool.Item{Value: item.Value, Priority: item.Value.Fee, EnterTime: newEnterTime} resendItem := &mempool.Item{Value: item.Value, Priority: item.Value.Fee, EnterTime: newEnterTime}
var err error var err error
sv, err := cache.newSkipValue(resendItem) sv, err := cache.SkipValue(resendItem)
if err != nil { if err != nil {
return err return err
} }
...@@ -80,7 +81,7 @@ func (cache *Queue) Push(item *mempool.Item) error { ...@@ -80,7 +81,7 @@ func (cache *Queue) Push(item *mempool.Item) error {
} }
it := &mempool.Item{Value: item.Value, Priority: item.Value.Fee, EnterTime: item.EnterTime} it := &mempool.Item{Value: item.Value, Priority: item.Value.Fee, EnterTime: item.EnterTime}
sv, err := cache.newSkipValue(it) sv, err := cache.SkipValue(it)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -376,8 +376,8 @@ func init() { ...@@ -376,8 +376,8 @@ func init() {
import ( import (
log "github.com/inconshreveable/log15" log "github.com/inconshreveable/log15"
drivers "gitlab.33.cn/chain33/chain33/system/dapp" drivers "github.com/33cn/chain33/system/dapp"
"gitlab.33.cn/chain33/chain33/types" "github.com/33cn/chain33/types"
) )
var clog = log.New("module", "execs.${EXECNAME}") var clog = log.New("module", "execs.${EXECNAME}")
...@@ -453,7 +453,7 @@ message ${ACTIONNAME}None { ...@@ -453,7 +453,7 @@ message ${ACTIONNAME}None {
CpftDappTypefile = `package types CpftDappTypefile = `package types
import ( import (
"gitlab.33.cn/chain33/chain33/types" "github.com/33cn/chain33/types"
) )
var ( var (
......
package score package skiplist
import ( import (
"fmt" "fmt"
...@@ -8,24 +8,23 @@ import ( ...@@ -8,24 +8,23 @@ import (
const maxLevel = 32 const maxLevel = 32
const prob = 0.35 const prob = 0.35
// SkipValue 跳跃表节点 // SkipValue 跳跃表节点的Value值
type SkipValue struct { type SkipValue struct {
Score int64 Score int64
Value interface{} Value interface{}
} }
// Compare 比较函数 // Compare 比较函数,这样的比较排序是从大到小
func (v *SkipValue) Compare(value *SkipValue) int { func (v *SkipValue) Compare(value *SkipValue) int {
f1 := v.Score if v.Score > value.Score {
f2 := value.Score
if f1 > f2 {
return -1 return -1
} else if f1 == f2 { } else if v.Score == value.Score {
return 0 return 0
} }
return 1 return 1
} }
// skipListNode 跳跃表节点
type skipListNode struct { type skipListNode struct {
next []*skipListNode next []*skipListNode
prev *skipListNode prev *skipListNode
...@@ -46,7 +45,7 @@ type SkipListIterator struct { ...@@ -46,7 +45,7 @@ type SkipListIterator struct {
node *skipListNode node *skipListNode
} }
// First 获取第一个节点 // First 获取第一个节点Value值
func (sli *SkipListIterator) First() *SkipValue { func (sli *SkipListIterator) First() *SkipValue {
if sli.list.header.next[0] == nil { if sli.list.header.next[0] == nil {
return nil return nil
...@@ -55,7 +54,7 @@ func (sli *SkipListIterator) First() *SkipValue { ...@@ -55,7 +54,7 @@ func (sli *SkipListIterator) First() *SkipValue {
return sli.node.Value return sli.node.Value
} }
// Last 获取最后一个节点 // Last 获取最后一个节点Value值
func (sli *SkipListIterator) Last() *SkipValue { func (sli *SkipListIterator) Last() *SkipValue {
if sli.list.tail == nil { if sli.list.tail == nil {
return nil return nil
...@@ -80,6 +79,16 @@ func (node *skipListNode) Next() *skipListNode { ...@@ -80,6 +79,16 @@ func (node *skipListNode) Next() *skipListNode {
return node.next[0] return node.next[0]
} }
// Seek 迭代器在跳跃表中查找某个位置在传参后面或者与传参相等的SkipValue
func (sli *SkipListIterator) Seek(value *SkipValue) *SkipValue {
x := sli.list.find(value)
if x.next[0] == nil {
return nil
}
sli.node = x.next[0]
return sli.node.Value
}
func newskipListNode(level int, value *SkipValue) *skipListNode { func newskipListNode(level int, value *SkipValue) *skipListNode {
node := &skipListNode{} node := &skipListNode{}
node.next = make([]*skipListNode, level) node.next = make([]*skipListNode, level)
...@@ -87,7 +96,7 @@ func newskipListNode(level int, value *SkipValue) *skipListNode { ...@@ -87,7 +96,7 @@ func newskipListNode(level int, value *SkipValue) *skipListNode {
return node return node
} }
//NewSkipList 构建一个value的最小值 //NewSkipList 构建一个跳跃表
func NewSkipList(min *SkipValue) *SkipList { func NewSkipList(min *SkipValue) *SkipList {
sl := &SkipList{} sl := &SkipList{}
sl.level = 1 sl.level = 1
...@@ -107,7 +116,7 @@ func randomLevel() int { ...@@ -107,7 +116,7 @@ func randomLevel() int {
return level return level
} }
// GetIterator 返回第一个节点 // GetIterator 获取迭代器
func (sl *SkipList) GetIterator() *SkipListIterator { func (sl *SkipList) GetIterator() *SkipListIterator {
it := &SkipListIterator{} it := &SkipListIterator{}
it.list = sl it.list = sl
...@@ -120,6 +129,11 @@ func (sl *SkipList) Len() int { ...@@ -120,6 +129,11 @@ func (sl *SkipList) Len() int {
return sl.count return sl.count
} }
// Level 返回跳跃表的层级
func (sl *SkipList) Level() int {
return sl.level
}
func (sl *SkipList) find(value *SkipValue) *skipListNode { func (sl *SkipList) find(value *SkipValue) *skipListNode {
x := sl.header x := sl.header
for i := sl.level - 1; i >= 0; i-- { for i := sl.level - 1; i >= 0; i-- {
...@@ -136,7 +150,7 @@ func (sl *SkipList) FindCount() int { ...@@ -136,7 +150,7 @@ func (sl *SkipList) FindCount() int {
return sl.findcount return sl.findcount
} }
// Find 查找skipvalue // Find 查找某个跳跃表中的SkipValue
func (sl *SkipList) Find(value *SkipValue) *SkipValue { func (sl *SkipList) Find(value *SkipValue) *SkipValue {
x := sl.find(value) x := sl.find(value)
if x.next[0] != nil && x.next[0].Value.Compare(value) == 0 { if x.next[0] != nil && x.next[0].Value.Compare(value) == 0 {
...@@ -145,6 +159,15 @@ func (sl *SkipList) Find(value *SkipValue) *SkipValue { ...@@ -145,6 +159,15 @@ func (sl *SkipList) Find(value *SkipValue) *SkipValue {
return nil return nil
} }
// FindGreaterOrEqual 在跳跃表中查找某个位置在传参后面或者与传参相等的SkipValue
func (sl *SkipList) FindGreaterOrEqual(value *SkipValue) *SkipValue {
x := sl.find(value)
if x.next[0] != nil {
return x.next[0].Value
}
return nil
}
// Insert 插入节点 // Insert 插入节点
func (sl *SkipList) Insert(value *SkipValue) int { func (sl *SkipList) Insert(value *SkipValue) int {
var update [maxLevel]*skipListNode var update [maxLevel]*skipListNode
......
package skiplist
import (
"testing"
"github.com/stretchr/testify/assert"
)
var(
s1=&SkipValue{1,"111"}
s2=&SkipValue{2,"222"}
s3=&SkipValue{3,"333"}
s4=&SkipValue{4,"444"}
)
func TestInsert(t *testing.T) {
l := NewSkipList(nil)
l.Insert(s1)
assert.Equal(t, 1, l.Len())
l.Insert(s2)
assert.Equal(t, 2, l.Len())
iter := l.GetIterator()
assert.Equal(t, int64(2), iter.First().Score)
assert.Equal(t, "222", iter.First().Value.(string))
assert.Equal(t, int64(1), iter.Last().Score)
assert.Equal(t, "111", iter.Last().Value.(string))
}
func TestFind(t *testing.T) {
l := NewSkipList(nil)
l.Insert(s1)
assert.Equal(t, s1, l.Find(s1))
l.Insert(s2)
assert.Equal(t, s2, l.Find(s2))
l.Insert(s3)
assert.Equal(t, s3, l.Find(s3))
}
func TestDelete(t *testing.T) {
l := NewSkipList(nil)
l.Insert(s1)
l.Insert(s2)
l.Delete(s1)
assert.Equal(t, 1, l.Len())
assert.Equal(t, (*SkipValue)(nil), l.Find(s1))
assert.Equal(t, s2, l.Find(s2))
}
func TestWalk(t *testing.T) {
l := NewSkipList(nil)
l.Insert(s1)
l.Insert(s2)
var data [2]string
i := 0
l.Walk(func(value interface{}) bool {
data[i] = value.(string)
i++
return true
})
assert.Equal(t, data[0], "222")
assert.Equal(t, data[1], "111")
var data2 [2]string
i = 0
l.Walk(func(value interface{}) bool {
data2[i] = value.(string)
i++
return false
})
assert.Equal(t, data2[0], "222")
assert.Equal(t, data2[1], "")
l.Walk(nil)
iter := l.GetIterator()
assert.Equal(t, int64(2), iter.First().Score)
assert.Equal(t, "222", iter.First().Value.(string))
assert.Equal(t, int64(1), iter.Last().Score)
assert.Equal(t, "111", iter.Last().Value.(string))
}
...@@ -6,7 +6,6 @@ package executor ...@@ -6,7 +6,6 @@ package executor
import ( import (
"bytes" "bytes"
"runtime/debug"
drivers "github.com/33cn/chain33/system/dapp" drivers "github.com/33cn/chain33/system/dapp"
"github.com/33cn/chain33/types" "github.com/33cn/chain33/types"
...@@ -62,27 +61,35 @@ func isAllowKeyWrite(key, realExecer []byte, tx *types.Transaction, height int64 ...@@ -62,27 +61,35 @@ func isAllowKeyWrite(key, realExecer []byte, tx *types.Transaction, height int64
} }
func isAllowLocalKey(execer []byte, key []byte) error { func isAllowLocalKey(execer []byte, key []byte) error {
execer = types.GetRealExecName(execer) if err := isAllowLocalKey2(execer, key); err != nil {
//println(string(execer), string(key)) realexec := types.GetRealExecName(execer)
if bytes.Equal(realexec, execer) {
return err
}
return isAllowLocalKey2(realexec, key)
}
return nil
}
func isAllowLocalKey2(execer []byte, key []byte) error {
if len(execer) < 1 {
return types.ErrLocalPrefix
}
minkeylen := len(types.LocalPrefix) + len(execer) + 2 minkeylen := len(types.LocalPrefix) + len(execer) + 2
if len(key) <= minkeylen { if len(key) <= minkeylen {
debug.PrintStack()
elog.Error("isAllowLocalKey too short", "key", string(key), "exec", string(execer)) elog.Error("isAllowLocalKey too short", "key", string(key), "exec", string(execer))
return types.ErrLocalKeyLen return types.ErrLocalKeyLen
} }
if key[minkeylen-1] != '-' { if key[minkeylen-1] != '-' {
debug.PrintStack()
elog.Error("isAllowLocalKey prefix last char is not '-'", "key", string(key), "exec", string(execer), elog.Error("isAllowLocalKey prefix last char is not '-'", "key", string(key), "exec", string(execer),
"minkeylen", minkeylen) "minkeylen", minkeylen)
return types.ErrLocalPrefix return types.ErrLocalPrefix
} }
if !bytes.HasPrefix(key, types.LocalPrefix) { if !bytes.HasPrefix(key, types.LocalPrefix) {
debug.PrintStack()
elog.Error("isAllowLocalKey common prefix not match", "key", string(key), "exec", string(execer)) elog.Error("isAllowLocalKey common prefix not match", "key", string(key), "exec", string(execer))
return types.ErrLocalPrefix return types.ErrLocalPrefix
} }
if !bytes.HasPrefix(key[len(types.LocalPrefix)+1:], execer) { if !bytes.HasPrefix(key[len(types.LocalPrefix)+1:], execer) {
debug.PrintStack()
elog.Error("isAllowLocalKey key prefix not match", "key", string(key), "exec", string(execer)) elog.Error("isAllowLocalKey key prefix not match", "key", string(key), "exec", string(execer))
return types.ErrLocalPrefix return types.ErrLocalPrefix
} }
......
...@@ -89,6 +89,9 @@ func New(cfg *types.Exec, sub map[string][]byte) *Executor { ...@@ -89,6 +89,9 @@ func New(cfg *types.Exec, sub map[string][]byte) *Executor {
return exec return exec
} }
//Wait Executor ready
func (exec *Executor) Wait() {}
// SetQueueClient set client queue, for recv msg // SetQueueClient set client queue, for recv msg
func (exec *Executor) SetQueueClient(qcli queue.Client) { func (exec *Executor) SetQueueClient(qcli queue.Client) {
exec.client = qcli exec.client = qcli
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"encoding/hex" "encoding/hex"
"github.com/33cn/chain33/queue"
_ "github.com/33cn/chain33/system" _ "github.com/33cn/chain33/system"
drivers "github.com/33cn/chain33/system/dapp" drivers "github.com/33cn/chain33/system/dapp"
"github.com/33cn/chain33/types" "github.com/33cn/chain33/types"
...@@ -21,6 +22,11 @@ func init() { ...@@ -21,6 +22,11 @@ func init() {
types.Init("local", nil) types.Init("local", nil)
} }
func TestIsModule(t *testing.T) {
var qmodule queue.Module = &Executor{}
assert.NotNil(t, qmodule)
}
func TestExecutorGetTxGroup(t *testing.T) { func TestExecutorGetTxGroup(t *testing.T) {
exec := &Executor{} exec := &Executor{}
execInit(nil) execInit(nil)
...@@ -113,77 +119,21 @@ func TestKeyAllow_evm(t *testing.T) { ...@@ -113,77 +119,21 @@ func TestKeyAllow_evm(t *testing.T) {
//assert.Nil(t, t) //assert.Nil(t, t)
} }
/*
func TestKeyAllow_evmallow(t *testing.T) {
execInit(nil)
key := []byte("mavl-evm-xxx")
exec := []byte("user.evm.0xc79c9113a71c0a4244e20f0780e7c13552f40ee30b05998a38edb08fe617aaa5")
tx1 := "0a05636f696e73120e18010a0a1080c2d72f1a036f746520a08d0630f1cdebc8f7efa5e9283a22313271796f6361794e46374c7636433971573461767873324537553431664b536676"
tx11, _ := hex.DecodeString(tx1)
var tx12 types.Transaction
types.Decode(tx11, &tx12)
tx12.Execer = exec
if !isAllowKeyWrite(key, exec, &tx12, int64(1)) {
t.Error("user.evm.hash can modify exec")
}
//assert.Nil(t, t)
}
func TestKeyAllow_paraallow(t *testing.T) {
execInit(nil)
key := []byte("mavl-noexec-xxx")
exec := []byte("user.p.user.noexec.0xc79c9113a71c0a4244e20f0780e7c13552f40ee30b05998a38edb08fe617aaa5")
tx1 := "0a05636f696e73120e18010a0a1080c2d72f1a036f746520a08d0630f1cdebc8f7efa5e9283a22313271796f6361794e46374c7636433971573461767873324537553431664b536676"
tx11, _ := hex.DecodeString(tx1)
var tx12 types.Transaction
types.Decode(tx11, &tx12)
tx12.Execer = exec
if isAllowKeyWrite(key, exec, &tx12, int64(1)) {
t.Error("user.noexec.hash can not modify noexec")
}
//assert.Nil(t, t)
}
func TestKeyAllow_ticket(t *testing.T) {
execInit(nil)
key := []byte("mavl-coins-bty-exec-16htvcBNSEA7fZhAdLJphDwQRQJaHpyHTp")
exec := []byte("ticket")
tx1 := "0a067469636b657412c701501022c20108dcaed4f1011080a4a7da061a70314556474572784d52343565577532386d6f4151616e6b34413864516635623639383a3078356161303431643363623561356230396131333336626536373539356638366461336233616564386531653733373139346561353135313562653336363933333a3030303030303030303022423078336461373533326364373839613330623037633538343564336537383433613731356630393961616566386533646161376134383765613135383135336331631a6e08011221025a317f60e6962b7ce9836a83b775373b614b290bee595f8aecee5499791831c21a473045022100850bb15cdcdaf465af7ad1ffcbc1fd6a86942a1ddec1dc112164f37297e06d2d02204aca9686fd169462be955cef1914a225726280739770ab1c0d29eb953e54c6b620a08d0630e3faecf8ead9f9e1483a22313668747663424e53454137665a6841644c4a706844775152514a61487079485470"
tx11, _ := hex.DecodeString(tx1)
var tx12 types.Transaction
types.Decode(tx11, &tx12)
tx12.Execer = exec
if !isAllowKeyWrite(key, exec, &tx12, int64(1)) {
t.Error("ticket can modify exec")
}
}
func TestKeyAllow_paracross(t *testing.T) {
execInit(nil)
key := []byte("mavl-coins-bty-exec-1HPkPopVe3ERfvaAgedDtJQ792taZFEHCe:19xXg1WHzti5hzBRTUphkM8YmuX6jJkoAA")
exec := []byte("paracross")
tx1 := "0a15757365722e702e746573742e7061726163726f7373124310904e223e1080c2d72f1a1374657374206173736574207472616e736665722222314a524e6a64457170344c4a356671796355426d396179434b536565736b674d4b5220a08d0630f7cba7ec9e8f9bac163a2231367a734d68376d764e444b50473645394e5672506877367a4c3933675773547052"
tx11, _ := hex.DecodeString(tx1)
var tx12 types.Transaction
types.Decode(tx11, &tx12)
tx12.Execer = []byte("user.p.para.paracross")
if !isAllowKeyWrite(key, exec, &tx12, int64(1)) {
t.Error("paracross can modify exec")
}
}
*/
func TestKeyLocalAllow(t *testing.T) { func TestKeyLocalAllow(t *testing.T) {
err := isAllowLocalKey([]byte("token"), []byte("LODB-token-")) err := isAllowLocalKey([]byte("token"), []byte("LODB-token-"))
assert.Equal(t, err, types.ErrLocalKeyLen) assert.Equal(t, err, types.ErrLocalKeyLen)
err = isAllowLocalKey([]byte("token"), []byte("LODB-token-a")) err = isAllowLocalKey([]byte("token"), []byte("LODB-token-a"))
assert.Nil(t, err) assert.Nil(t, err)
err = isAllowLocalKey([]byte(""), []byte("LODB--a")) err = isAllowLocalKey([]byte(""), []byte("LODB--a"))
assert.Nil(t, err) assert.Equal(t, err, types.ErrLocalPrefix)
err = isAllowLocalKey([]byte("exec"), []byte("LODB-execaa")) err = isAllowLocalKey([]byte("exec"), []byte("LODB-execaa"))
assert.Equal(t, err, types.ErrLocalPrefix) assert.Equal(t, err, types.ErrLocalPrefix)
err = isAllowLocalKey([]byte("exec"), []byte("-exec------aa")) err = isAllowLocalKey([]byte("exec"), []byte("-exec------aa"))
assert.Equal(t, err, types.ErrLocalPrefix) assert.Equal(t, err, types.ErrLocalPrefix)
err = isAllowLocalKey([]byte("paracross"), []byte("LODB-user.p.para.paracross-xxxx")) err = isAllowLocalKey([]byte("paracross"), []byte("LODB-user.p.para.paracross-xxxx"))
assert.Equal(t, err, types.ErrLocalPrefix) assert.Equal(t, err, types.ErrLocalPrefix)
err = isAllowLocalKey([]byte("user.p.para.paracross"), []byte("LODB-user.p.para.paracross-xxxx"))
assert.Nil(t, err)
err = isAllowLocalKey([]byte("user.p.para.paracross"), []byte("LODB-paracross-xxxx"))
assert.Nil(t, err)
} }
...@@ -184,7 +184,9 @@ func (c *channelClient) GetAllExecBalance(in *types.ReqAddr) (*types.AllExecBala ...@@ -184,7 +184,9 @@ func (c *channelClient) GetAllExecBalance(in *types.ReqAddr) (*types.AllExecBala
addr := in.Addr addr := in.Addr
err := address.CheckAddress(addr) err := address.CheckAddress(addr)
if err != nil { if err != nil {
return nil, types.ErrInvalidAddress if err = address.CheckMultiSignAddress(addr); err != nil {
return nil, types.ErrInvalidAddress
}
} }
var addrs []string var addrs []string
addrs = append(addrs, addr) addrs = append(addrs, addr)
......
...@@ -1229,7 +1229,7 @@ func TestChain33_CreateTransaction(t *testing.T) { ...@@ -1229,7 +1229,7 @@ func TestChain33_CreateTransaction(t *testing.T) {
in := &rpctypes.CreateTxIn{Execer: "notExist", ActionName: "x", Payload: []byte("x")} in := &rpctypes.CreateTxIn{Execer: "notExist", ActionName: "x", Payload: []byte("x")}
err = client.CreateTransaction(in, &result) err = client.CreateTransaction(in, &result)
assert.Equal(t, types.ErrNotSupport, err) assert.Equal(t, types.ErrExecNotFound, err)
in = &rpctypes.CreateTxIn{Execer: types.ExecName("coins"), ActionName: "notExist", Payload: []byte("x")} in = &rpctypes.CreateTxIn{Execer: types.ExecName("coins"), ActionName: "notExist", Payload: []byte("x")}
err = client.CreateTransaction(in, &result) err = client.CreateTransaction(in, &result)
......
...@@ -70,7 +70,7 @@ func (client *JSONClient) Call(method string, params, resp interface{}) error { ...@@ -70,7 +70,7 @@ func (client *JSONClient) Call(method string, params, resp interface{}) error {
req := &clientRequest{} req := &clientRequest{}
req.Method = method req.Method = method
req.Params[0] = params req.Params[0] = params
data, err := json.MarshalIndent(req, "", "\t") data, err := json.Marshal(req)
if err != nil { if err != nil {
return err return err
} }
...@@ -116,5 +116,4 @@ func (client *JSONClient) Call(method string, params, resp interface{}) error { ...@@ -116,5 +116,4 @@ func (client *JSONClient) Call(method string, params, resp interface{}) error {
return types.JSONToPB(b, msg) return types.JSONToPB(b, msg)
} }
return json.Unmarshal(*cresp.Result, resp) return json.Unmarshal(*cresp.Result, resp)
} }
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package rpc_test package rpc_test
import ( import (
"fmt"
"testing" "testing"
"github.com/33cn/chain33/common" "github.com/33cn/chain33/common"
...@@ -97,3 +98,39 @@ func TestSendToExec(t *testing.T) { ...@@ -97,3 +98,39 @@ func TestSendToExec(t *testing.T) {
balance := mocker.GetExecAccount(block.StateHash, "user.f3d", mocker.GetGenesisAddress()).Balance balance := mocker.GetExecAccount(block.StateHash, "user.f3d", mocker.GetGenesisAddress()).Balance
assert.Equal(t, int64(10), balance) assert.Equal(t, int64(10), balance)
} }
func TestGetAllExecBalance(t *testing.T) {
mocker := testnode.New("--free--", nil)
defer mocker.Close()
mocker.Listen()
jrpcClient := getRPCClient(t, mocker)
addr := "38BRY193Wvy9MkdqMjmuaYeUHnJaFjUxMP"
req := types.ReqAddr{Addr: addr}
var res rpctypes.AllExecBalance
err := jrpcClient.Call("Chain33.GetAllExecBalance", req, &res)
assert.Nil(t, err)
assert.Equal(t, addr, res.Addr)
assert.Nil(t, res.ExecAccount)
assert.Equal(t, 0, len(res.ExecAccount))
}
func TestCreateTransactionUserWrite(t *testing.T) {
mocker := testnode.New("--free--", nil)
defer mocker.Close()
mocker.Listen()
jrpcClient := getRPCClient(t, mocker)
req := &rpctypes.CreateTxIn{
Execer: "user.write",
ActionName: "write",
Payload: []byte(`{"key":"value"}`),
}
var res string
err := jrpcClient.Call("Chain33.CreateTransaction", req, &res)
assert.Nil(t, err)
tx := getTx(t, res)
assert.NotNil(t, tx)
fmt.Println(string(tx.Payload))
assert.Nil(t, err)
assert.Equal(t, `{"key":"value"}`, string(tx.Payload))
}
...@@ -138,8 +138,10 @@ func balance(cmd *cobra.Command, args []string) { ...@@ -138,8 +138,10 @@ func balance(cmd *cobra.Command, args []string) {
height, _ := cmd.Flags().GetInt("height") height, _ := cmd.Flags().GetInt("height")
err := address.CheckAddress(addr) err := address.CheckAddress(addr)
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, types.ErrInvalidAddress) if err = address.CheckMultiSignAddress(addr); err != nil {
return fmt.Fprintln(os.Stderr, types.ErrInvalidAddress)
return
}
} }
if execer == "" { if execer == "" {
req := types.ReqAddr{Addr: addr} req := types.ReqAddr{Addr: addr}
......
...@@ -59,6 +59,7 @@ var ( ...@@ -59,6 +59,7 @@ var (
ErrEmpty = errors.New("ErrEmpty") ErrEmpty = errors.New("ErrEmpty")
ErrSendSameToRecv = errors.New("ErrSendSameToRecv") ErrSendSameToRecv = errors.New("ErrSendSameToRecv")
ErrExecNameNotAllow = errors.New("ErrExecNameNotAllow") ErrExecNameNotAllow = errors.New("ErrExecNameNotAllow")
ErrExecNotFound = errors.New("ErrExecNotFound")
ErrLocalDBPerfix = errors.New("ErrLocalDBPerfix") ErrLocalDBPerfix = errors.New("ErrLocalDBPerfix")
ErrTimeout = errors.New("ErrTimeout") ErrTimeout = errors.New("ErrTimeout")
ErrBlockHeaderDifficulty = errors.New("ErrBlockHeaderDifficulty") ErrBlockHeaderDifficulty = errors.New("ErrBlockHeaderDifficulty")
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package types package types
import ( import (
"bytes"
"encoding/json" "encoding/json"
"math/rand" "math/rand"
"reflect" "reflect"
...@@ -137,8 +138,14 @@ func CallCreateTx(execName, action string, param Message) ([]byte, error) { ...@@ -137,8 +138,14 @@ func CallCreateTx(execName, action string, param Message) ([]byte, error) {
func CallCreateTxJSON(execName, action string, param json.RawMessage) ([]byte, error) { func CallCreateTxJSON(execName, action string, param json.RawMessage) ([]byte, error) {
exec := LoadExecutorType(execName) exec := LoadExecutorType(execName)
if exec == nil { if exec == nil {
execer := GetParaExecName([]byte(execName))
//找不到执行器,并且是user.xxx 的情况下
if bytes.HasPrefix(execer, UserKey) {
tx := &Transaction{Payload: param}
return FormatTxEncode(execName, tx)
}
tlog.Error("CallCreateTxJSON", "Error", "exec not found") tlog.Error("CallCreateTxJSON", "Error", "exec not found")
return nil, ErrNotSupport return nil, ErrExecNotFound
} }
// param is interface{type, var-nil}, check with nil always fail // param is interface{type, var-nil}, check with nil always fail
if param == nil { if param == nil {
......
...@@ -142,7 +142,8 @@ func GetParaExec(execer []byte) []byte { ...@@ -142,7 +142,8 @@ func GetParaExec(execer []byte) []byte {
return execer[len(GetTitle()):] return execer[len(GetTitle()):]
} }
func getParaExecName(execer []byte) []byte { //GetParaExecName 获取平行链上的执行器
func GetParaExecName(execer []byte) []byte {
if !bytes.HasPrefix(execer, ParaKey) { if !bytes.HasPrefix(execer, ParaKey) {
return execer return execer
} }
...@@ -162,7 +163,7 @@ func getParaExecName(execer []byte) []byte { ...@@ -162,7 +163,7 @@ func getParaExecName(execer []byte) []byte {
//GetRealExecName 获取真实的执行器name //GetRealExecName 获取真实的执行器name
func GetRealExecName(execer []byte) []byte { func GetRealExecName(execer []byte) []byte {
//平行链执行器,获取真实执行器的规则 //平行链执行器,获取真实执行器的规则
execer = getParaExecName(execer) execer = GetParaExecName(execer)
//平行链嵌套平行链是不被允许的 //平行链嵌套平行链是不被允许的
if bytes.HasPrefix(execer, ParaKey) { if bytes.HasPrefix(execer, ParaKey) {
return execer return execer
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment