Commit 77243bdc authored by linj's avatar linj Committed by vipwzw

trade upgrade return kvset

parent e40b85b8
...@@ -28,52 +28,63 @@ const ( ...@@ -28,52 +28,63 @@ const (
) )
// Upgrade 实现升级接口 // Upgrade 实现升级接口
func (t *trade) Upgrade() error { func (t *trade) Upgrade() (*types.LocalDBSet, error) {
localDB := t.GetLocalDB() localDB := t.GetLocalDB()
// 获得默认的coins symbol, 更新数据时用 // 获得默认的coins symbol, 更新数据时用
coinSymbol := t.GetAPI().GetConfig().GetCoinSymbol() coinSymbol := t.GetAPI().GetConfig().GetCoinSymbol()
err := UpgradeLocalDBV2(localDB, coinSymbol) kvs, err := UpgradeLocalDBV2(localDB, coinSymbol)
if err != nil { if err != nil {
tradelog.Error("Upgrade failed", "err", err) tradelog.Error("Upgrade failed", "err", err)
return errors.Cause(err) return nil, errors.Cause(err)
} }
return nil return kvs, nil
} }
// UpgradeLocalDBV2 trade 本地数据库升级 // UpgradeLocalDBV2 trade 本地数据库升级
// from 1 to 2 // from 1 to 2
func UpgradeLocalDBV2(localDB dbm.KVDB, coinSymbol string) error { func UpgradeLocalDBV2(localDB dbm.KVDB, coinSymbol string) (*types.LocalDBSet, error) {
toVersion := 2 toVersion := 2
tradelog.Info("UpgradeLocalDBV2 upgrade start", "to_version", toVersion) tradelog.Info("UpgradeLocalDBV2 upgrade start", "to_version", toVersion)
version, err := getVersion(localDB) version, err := getVersion(localDB)
if err != nil { if err != nil {
return errors.Wrap(err, "UpgradeLocalDBV2 get version") return nil, errors.Wrap(err, "UpgradeLocalDBV2 get version")
} }
if version >= toVersion { if version >= toVersion {
tradelog.Debug("UpgradeLocalDBV2 not need to upgrade", "current_version", version, "to_version", toVersion) tradelog.Debug("UpgradeLocalDBV2 not need to upgrade", "current_version", version, "to_version", toVersion)
return nil return nil, nil
} }
err = UpgradeLocalDBPart2(localDB, coinSymbol) var kvset types.LocalDBSet
kvs, err := UpgradeLocalDBPart2(localDB, coinSymbol)
if err != nil { if err != nil {
return errors.Wrap(err, "UpgradeLocalDBV2 UpgradeLocalDBPart2") return nil, errors.Wrap(err, "UpgradeLocalDBV2 UpgradeLocalDBPart2")
}
if kvs != nil && len(kvs) > 0 {
kvset.KV = append(kvset.KV, kvs...)
} }
err = UpgradeLocalDBPart1(localDB) kvs, err = UpgradeLocalDBPart1(localDB)
if err != nil { if err != nil {
return errors.Wrap(err, "UpgradeLocalDBV2 UpgradeLocalDBPart1") return nil, errors.Wrap(err, "UpgradeLocalDBV2 UpgradeLocalDBPart1")
}
if kvs != nil && len(kvs) > 0 {
kvset.KV = append(kvset.KV, kvs...)
} }
err = setVersion(localDB, toVersion)
kvs, err = setVersion(localDB, toVersion)
if err != nil { if err != nil {
return errors.Wrap(err, "UpgradeLocalDBV2 setVersion") return nil, errors.Wrap(err, "UpgradeLocalDBV2 setVersion")
}
if kvs != nil && len(kvs) > 0 {
kvset.KV = append(kvset.KV, kvs...)
} }
tradelog.Info("UpgradeLocalDBV2 upgrade done") tradelog.Info("UpgradeLocalDBV2 upgrade done")
return nil return &kvset, nil
} }
// UpgradeLocalDBPart1 手动生成KV,需要在原有数据库中删除 // UpgradeLocalDBPart1 手动生成KV,需要在原有数据库中删除
func UpgradeLocalDBPart1(localDB dbm.KVDB) error { func UpgradeLocalDBPart1(localDB dbm.KVDB) ([]*types.KeyValue, error) {
prefixes := []string{ prefixes := []string{
sellOrderSHTAS, sellOrderSHTAS,
sellOrderASTS, sellOrderASTS,
...@@ -86,44 +97,52 @@ func UpgradeLocalDBPart1(localDB dbm.KVDB) error { ...@@ -86,44 +97,52 @@ func UpgradeLocalDBPart1(localDB dbm.KVDB) error {
orderASTHK, orderASTHK,
} }
var allKvs []*types.KeyValue
for _, prefix := range prefixes { for _, prefix := range prefixes {
err := delOnePrefix(localDB, prefix) kvs, err := delOnePrefix(localDB, prefix)
if err != nil { if err != nil {
return errors.Wrapf(err, "UpdateLocalDBPart1 delOnePrefix: %s", prefix) return nil, errors.Wrapf(err, "UpdateLocalDBPart1 delOnePrefix: %s", prefix)
} }
if kvs != nil && len(kvs) > 0 {
allKvs = append(allKvs, kvs...)
}
} }
return nil return allKvs, nil
} }
// delOnePrefix 删除指定前缀的记录 // delOnePrefix 删除指定前缀的记录
func delOnePrefix(localDB dbm.KVDB, prefix string) error { func delOnePrefix(localDB dbm.KVDB, prefix string) ([]*types.KeyValue, error) {
start := []byte(prefix) start := []byte(prefix)
keys, err := localDB.List(start, nil, 0, dbm.ListASC|dbm.ListKeyOnly) keys, err := localDB.List(start, nil, 0, dbm.ListASC|dbm.ListKeyOnly)
if err != nil { if err != nil {
if err == types.ErrNotFound { if err == types.ErrNotFound {
return nil return nil, nil
} }
return err return nil, err
} }
var kvs []*types.KeyValue
tradelog.Debug("delOnePrefix", "len", len(keys), "prefix", prefix) tradelog.Debug("delOnePrefix", "len", len(keys), "prefix", prefix)
for _, key := range keys { for _, key := range keys {
err = localDB.Set(key, nil) err = localDB.Set(key, nil)
if err != nil { if err != nil {
return err return nil, err
} }
kvs = append(kvs, &types.KeyValue{Key: key, Value: nil})
} }
return nil return kvs, nil
} }
// UpgradeLocalDBPart2 升级order // UpgradeLocalDBPart2 升级order
// order 从 v1 升级到 v2 // order 从 v1 升级到 v2
// 通过tableV1 删除, 通过tableV2 添加, 无需通过每个区块扫描对应的交易 // 通过tableV1 删除, 通过tableV2 添加, 无需通过每个区块扫描对应的交易
func UpgradeLocalDBPart2(kvdb dbm.KVDB, coinSymbol string) error { func UpgradeLocalDBPart2(kvdb dbm.KVDB, coinSymbol string) ([]*types.KeyValue, error) {
return upgradeOrder(kvdb, coinSymbol) return upgradeOrder(kvdb, coinSymbol)
} }
func upgradeOrder(kvdb dbm.KVDB, coinSymbol string) (err error) { func upgradeOrder(kvdb dbm.KVDB, coinSymbol string) ([]*types.KeyValue, error) {
tab2 := NewOrderTableV2(kvdb) tab2 := NewOrderTableV2(kvdb)
tab := NewOrderTable(kvdb) tab := NewOrderTable(kvdb)
q1 := tab.GetQuery(kvdb) q1 := tab.GetQuery(kvdb)
...@@ -132,38 +151,38 @@ func upgradeOrder(kvdb dbm.KVDB, coinSymbol string) (err error) { ...@@ -132,38 +151,38 @@ func upgradeOrder(kvdb dbm.KVDB, coinSymbol string) (err error) {
rows, err := q1.List("key", &order1, []byte(""), 0, 0) rows, err := q1.List("key", &order1, []byte(""), 0, 0)
if err != nil { if err != nil {
if err == types.ErrNotFound { if err == types.ErrNotFound {
return nil return nil, nil
} }
return errors.Wrap(err, "upgradeOrder list from order v1 table") return nil, errors.Wrap(err, "upgradeOrder list from order v1 table")
} }
tradelog.Debug("upgradeOrder", "len", len(rows)) tradelog.Debug("upgradeOrder", "len", len(rows))
for _, row := range rows { for _, row := range rows {
o1, ok := row.Data.(*pty.LocalOrder) o1, ok := row.Data.(*pty.LocalOrder)
if !ok { if !ok {
return errors.Wrap(types.ErrTypeAsset, "decode order v1") return nil, errors.Wrap(types.ErrTypeAsset, "decode order v1")
} }
o2 := types.Clone(o1).(*pty.LocalOrder) o2 := types.Clone(o1).(*pty.LocalOrder)
upgradeLocalOrder(o2, coinSymbol) upgradeLocalOrder(o2, coinSymbol)
err = tab2.Add(o2) err = tab2.Add(o2)
if err != nil { if err != nil {
return errors.Wrap(err, "upgradeOrder add to order v2 table") return nil, errors.Wrap(err, "upgradeOrder add to order v2 table")
} }
err = tab.Del([]byte(o1.TxIndex)) err = tab.Del([]byte(o1.TxIndex))
if err != nil { if err != nil {
return errors.Wrapf(err, "upgradeOrder del from order v1 table, key: %s", o1.TxIndex) return nil, errors.Wrapf(err, "upgradeOrder del from order v1 table, key: %s", o1.TxIndex)
} }
} }
kvs, err := tab2.Save() kvs, err := tab2.Save()
if err != nil { if err != nil {
return errors.Wrap(err, "upgradeOrder save-add to order v2 table") return nil, errors.Wrap(err, "upgradeOrder save-add to order v2 table")
} }
kvs2, err := tab.Save() kvs2, err := tab.Save()
if err != nil { if err != nil {
return errors.Wrap(err, "upgradeOrder save-del to order v1 table") return nil, errors.Wrap(err, "upgradeOrder save-del to order v1 table")
} }
kvs = append(kvs, kvs2...) kvs = append(kvs, kvs2...)
...@@ -171,11 +190,11 @@ func upgradeOrder(kvdb dbm.KVDB, coinSymbol string) (err error) { ...@@ -171,11 +190,11 @@ func upgradeOrder(kvdb dbm.KVDB, coinSymbol string) (err error) {
tradelog.Debug("upgradeOrder", "KEY", string(kv.GetKey())) tradelog.Debug("upgradeOrder", "KEY", string(kv.GetKey()))
err = kvdb.Set(kv.GetKey(), kv.GetValue()) err = kvdb.Set(kv.GetKey(), kv.GetValue())
if err != nil { if err != nil {
return errors.Wrapf(err, "upgradeOrder set localdb key: %s", string(kv.GetKey())) return nil, errors.Wrapf(err, "upgradeOrder set localdb key: %s", string(kv.GetKey()))
} }
} }
return nil return kvs, nil
} }
// upgradeLocalOrder 处理两个fork前的升级数据 // upgradeLocalOrder 处理两个fork前的升级数据
...@@ -208,8 +227,9 @@ func getVersion(kvdb dbm.KV) (int, error) { ...@@ -208,8 +227,9 @@ func getVersion(kvdb dbm.KV) (int, error) {
return int(v.Data), nil return int(v.Data), nil
} }
func setVersion(kvdb dbm.KV, version int) error { func setVersion(kvdb dbm.KV, version int) ([]*types.KeyValue, error) {
v := types.Int32{Data: int32(version)} v := types.Int32{Data: int32(version)}
x := types.Encode(&v) x := types.Encode(&v)
return kvdb.Set([]byte(tradeLocaldbVersioin), x) err := kvdb.Set([]byte(tradeLocaldbVersioin), x)
return []*types.KeyValue{{Key: []byte(tradeLocaldbVersioin), Value: x}}, err
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment