diff --git a/script/bql.go b/script/bql.go index 2573990..e02ab46 100644 --- a/script/bql.go +++ b/script/bql.go @@ -17,6 +17,8 @@ type QueryParams struct { FromYear int `bql:"year ="` FromMonth int `bql:"month ="` Where bool `bql:"where"` + ID string `bql:"id ="` + IDList string `bql:"id in"` Currency string `bql:"currency ="` Year int `bql:"year ="` Month int `bql:"month ="` diff --git a/service/stats.go b/service/stats.go index b2e798e..4388f4e 100644 --- a/service/stats.go +++ b/service/stats.go @@ -3,9 +3,7 @@ package service import ( "encoding/json" "fmt" - "math" "sort" - "strconv" "strings" "time" @@ -65,7 +63,7 @@ func StatsTotal(c *gin.Context) { } type StatsQuery struct { - Prefix string `form:"prefix" binding:"required"` + Prefix string `form:"prefix"` Year int `form:"year"` Month int `form:"month"` Level int `form:"level"` @@ -252,32 +250,104 @@ type AccountSankeyNode struct { Name string `json:"name"` } type AccountSankeyLink struct { - Source int `json:"source"` - Target int `json:"target"` - Value string `json:"value"` + Source int `json:"source"` + Target int `json:"target"` + Value decimal.Decimal `json:"value"` } func NewAccountSankeyLink() *AccountSankeyLink { return &AccountSankeyLink{ Source: -1, Target: -1, - Value: "", } } +type TransactionAccountPositionBQLResult struct { + Id string + Account string + Position string +} + +type TransactionAccountPosition struct { + Id string + Account string + AccountName string + Value decimal.Decimal + OperatingCurrency string +} + +// StatsAccountSankey 统计账户流向 func StatsAccountSankey(c *gin.Context) { ledgerConfig := script.GetLedgerConfigFromContext(c) - queryParams := script.GetQueryParams(c) - // 倒序查询 - queryParams.OrderBy = "date desc" - transactions := make([]Transaction, 0) - err := script.BQLQueryList(ledgerConfig, &queryParams, &transactions) + var statsQuery StatsQuery + if err := c.ShouldBindQuery(&statsQuery); err != nil { + BadRequest(c, err.Error()) + return + } + queryParams := script.QueryParams{ + AccountLike: statsQuery.Prefix, + Year: statsQuery.Year, + Month: statsQuery.Month, + Where: true, + } + statsQueryResultList := make([]TransactionAccountPositionBQLResult, 0) + var bql string + // 账户不为空,则查询时间范围内所有涉及该账户的交易记录 + if statsQuery.Prefix != "" { + // 清空 account 查询条件,改为使用 ID 查询包含该账户所有交易记录 + queryParams.AccountLike = "" + bql = "SELECT '\\', id, '\\'" + err := script.BQLQueryListByCustomSelect(ledgerConfig, bql, &queryParams, &statsQueryResultList) + if err != nil { + InternalError(c, err.Error()) + return + } + if len(statsQueryResultList) != 0 { + idSet := make(map[string]bool) + for _, bqlResult := range statsQueryResultList { + idSet[bqlResult.Id] = true + } + idList := make([]string, 0, len(idSet)) + for id := range idSet { + idList = append(idList, id) + } + queryParams.IDList = strings.Join(idList, "|") + } + } + + if statsQuery.Level != 0 { + prefixNodeLen := len(strings.Split(strings.Trim(statsQuery.Prefix, ":"), ":")) + bql = fmt.Sprintf("SELECT '\\', id, '\\', root(account, %d) as subAccount, '\\', sum(convert(value(position), '%s')), '\\'", statsQuery.Level+prefixNodeLen, ledgerConfig.OperatingCurrency) + } else { + bql = fmt.Sprintf("SELECT '\\', id, '\\', account, '\\', sum(convert(value(position), '%s')), '\\'", ledgerConfig.OperatingCurrency) + } + + statsQueryResultList = make([]TransactionAccountPositionBQLResult, 0) + err := script.BQLQueryListByCustomSelect(ledgerConfig, bql, &queryParams, &statsQueryResultList) if err != nil { InternalError(c, err.Error()) return } + result := make([]Transaction, 0) + for _, queryRes := range statsQueryResultList { + if queryRes.Position != "" { + fields := strings.Fields(queryRes.Position) + result = append(result, Transaction{ + Id: queryRes.Id, + Account: queryRes.Account, + Number: fields[0], + Currency: fields[1], + }) + } + } + OK(c, buildSankeyResult(result)) +} + +func buildSankeyResult(transactions []Transaction) AccountSankeyResult { accountSankeyResult := AccountSankeyResult{} + accountSankeyResult.Nodes = make([]AccountSankeyNode, 0) + accountSankeyResult.Links = make([]AccountSankeyLink, 0) // 构建 nodes 和 links var nodes []AccountSankeyNode @@ -285,9 +355,9 @@ func StatsAccountSankey(c *gin.Context) { if len(transactions) > 0 { for _, transaction := range transactions { // 如果nodes中不存在该节点,则添加 - accountName := script.GetAccountName(transaction.Account) - if !contains(nodes, accountName) { - nodes = append(nodes, AccountSankeyNode{Name: accountName}) + account := transaction.Account + if !contains(nodes, account) { + nodes = append(nodes, AccountSankeyNode{Name: account}) } } accountSankeyResult.Nodes = nodes @@ -311,31 +381,30 @@ func StatsAccountSankey(c *gin.Context) { transaction := transactions[0] transactions = transactions[1:] - accountName := script.GetAccountName(transaction.Account) - num, err := strconv.ParseFloat(transaction.Number, 64) + account := transaction.Account + num, err := decimal.NewFromString(transaction.Number) if err != nil { continue } - if currentLinkNode.Source == -1 && num < 0 { + if currentLinkNode.Source == -1 && num.IsNegative() { if sourceTransaction.Account == "" { sourceTransaction = transaction } - currentLinkNode.Source = indexOf(nodes, accountName) + currentLinkNode.Source = indexOf(nodes, account) if currentLinkNode.Target == -1 { - currentLinkNode.Value = strconv.FormatFloat(num, 'f', 2, 64) + currentLinkNode.Value = num } else { // 比较 link node value 和 num 大小 - value, _ := strconv.ParseFloat(currentLinkNode.Value, 64) - delta := value + num - if delta == 0 { - currentLinkNode.Value = strconv.FormatFloat(math.Abs(num), 'f', 2, 64) - } else if delta < 0 { // source > target - targetNumber, _ := strconv.ParseFloat(targetTransaction.Number, 64) - currentLinkNode.Value = strconv.FormatFloat(math.Abs(targetNumber), 'f', 2, 64) - sourceTransaction.Number = strconv.FormatFloat(delta, 'f', 2, 64) + delta := currentLinkNode.Value.Add(num) + if delta.IsZero() { + currentLinkNode.Value = num.Abs() + } else if delta.IsNegative() { // source > target + targetNumber, _ := decimal.NewFromString(targetTransaction.Number) + currentLinkNode.Value = targetNumber.Abs() + sourceTransaction.Number = delta.String() transactions = append(transactions, sourceTransaction) } else { // source < target - targetTransaction.Number = strconv.FormatFloat(delta, 'f', 2, 64) + targetTransaction.Number = delta.String() transactions = append(transactions, targetTransaction) } // 完成一个 linkNode 的构建,重置判定条件 @@ -344,26 +413,25 @@ func StatsAccountSankey(c *gin.Context) { links = append(links, *currentLinkNode) currentLinkNode = NewAccountSankeyLink() } - } else if currentLinkNode.Target == -1 && num > 0 { + } else if currentLinkNode.Target == -1 && num.IsPositive() { if targetTransaction.Account == "" { targetTransaction = transaction } - currentLinkNode.Target = indexOf(nodes, accountName) + currentLinkNode.Target = indexOf(nodes, account) if currentLinkNode.Source == -1 { - currentLinkNode.Value = strconv.FormatFloat(num, 'f', 2, 64) + currentLinkNode.Value = num } else { - value, _ := strconv.ParseFloat(currentLinkNode.Value, 64) - delta := value + num - if delta == 0 { - currentLinkNode.Value = strconv.FormatFloat(math.Abs(num), 'f', 2, 64) - } else if delta < 0 { // source > target - currentLinkNode.Value = strconv.FormatFloat(math.Abs(num), 'f', 2, 64) - sourceTransaction.Number = strconv.FormatFloat(delta, 'f', 2, 64) + delta := currentLinkNode.Value.Add(num) + if delta.IsZero() { + currentLinkNode.Value = num.Abs() + } else if delta.IsNegative() { // source > target + currentLinkNode.Value = num.Abs() + sourceTransaction.Number = delta.String() transactions = append(transactions, sourceTransaction) } else { // source < target - sourceNumber, _ := strconv.ParseFloat(sourceTransaction.Number, 64) - currentLinkNode.Value = strconv.FormatFloat(math.Abs(sourceNumber), 'f', 2, 64) - targetTransaction.Number = strconv.FormatFloat(delta, 'f', 2, 64) + sourceNumber, _ := decimal.NewFromString(sourceTransaction.Number) + currentLinkNode.Value = sourceNumber.Abs() + targetTransaction.Number = delta.String() transactions = append(transactions, targetTransaction) } // 完成一个 linkNode 的构建,重置判定条件 @@ -379,10 +447,9 @@ func StatsAccountSankey(c *gin.Context) { maxCycle -= 1 } } - accountSankeyResult.Links = links + accountSankeyResult.Links = aggregateLinkNodes(links) } - - OK(c, accountSankeyResult) + return accountSankeyResult } func contains(nodes []AccountSankeyNode, str string) bool { @@ -415,6 +482,33 @@ func groupTransactionsByID(transactions []Transaction) map[string][]Transaction return grouped } +// 聚合函数,聚合相同 source 和 target 的值 +func aggregateLinkNodes(nodes []AccountSankeyLink) []AccountSankeyLink { + // 创建一个映射 key 为 "source-target",value 为 LinkNode + nodeMap := make(map[string]AccountSankeyLink) + + for _, node := range nodes { + key := fmt.Sprintf("%d-%d", node.Source, node.Target) + + if existingNode, found := nodeMap[key]; found { + // 如果已经存在相同的 source 和 target,累加 value + existingNode.Value = existingNode.Value.Add(node.Value) + nodeMap[key] = existingNode + } else { + // 否则直接插入新的 LinkNode + nodeMap[key] = node + } + } + + // 将 map 转换为 slice + result := make([]AccountSankeyLink, 0, len(nodeMap)) + for _, aggregatedNode := range nodeMap { + result = append(result, aggregatedNode) + } + + return result +} + type MonthTotalBQLResult struct { Year int Month int