beancount-gs/server.go

248 lines
8.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"flag"
"fmt"
"io"
"net/http"
"os"
"github.com/beancount-gs/script"
"github.com/beancount-gs/service"
"github.com/beancount-gs/utils/venv"
"github.com/gin-gonic/gin"
)
// Package-level state kept global so other parts of package main can use it.
var (
	// venvExecutor runs commands inside the Python virtual environment.
	// It stays nil when initVenvExecutor finds no virtual environment.
	venvExecutor *venv.VenvExecutor
	// venvPath is the virtual-environment directory handed to initVenvExecutor
	// (from the -venv command-line flag in main).
	venvPath string
)
// InitServerFiles makes sure the ledger data directory named by the server
// configuration's DataPath exists, creating it when it is missing.
// An empty DataPath is left alone (nothing is created).
// It returns the error from script.MkDir, or nil when no directory had to be
// created.
func InitServerFiles() error {
	path := script.GetServerConfig().DataPath
	// Nothing to do when no path is configured or the path already exists.
	if path == "" || script.FileIfExist(path) {
		return nil
	}
	return script.MkDir(path)
}
// LoadServerCache loads the server's in-memory caches: first the ledger
// configuration map, then the ledger accounts map.
// It stops at, and returns, the first error encountered.
func LoadServerCache() error {
	if err := script.LoadLedgerConfigMap(); err != nil {
		return err
	}
	return script.LoadLedgerAccountsMap()
}
// AuthorizedHandler returns a gin middleware that resolves the ledger named by
// the "ledgerId" request header through script.GetLedgerConfig.
// When a config is found, it is stored in the context under the key
// "LedgerConfig" and the handler chain continues. Otherwise the request is
// answered via service.Unauthorized and the remaining handlers are aborted.
func AuthorizedHandler() gin.HandlerFunc {
	return func(c *gin.Context) {
		cfg := script.GetLedgerConfig(c.GetHeader("ledgerId"))
		if cfg == nil {
			service.Unauthorized(c)
			c.Abort()
			return
		}
		c.Set("LedgerConfig", cfg)
		c.Next()
	}
}
// RegisterRouter installs every HTTP route on router:
//   - "/" redirects (301) to "/web", which serves static files from ./public;
//   - public API endpoints under /api need no ledger;
//   - endpoints under /api/auth/ first run AuthorizedHandler, which requires a
//     valid "ledgerId" header.
func RegisterRouter(router *gin.Engine) {
	// Redirect the root explicitly instead of serving static files at "/":
	// a root wildcard conflicts with the other routes in gin, see
	// https://github.com/gin-gonic/gin/issues/360
	router.GET("/", func(c *gin.Context) {
		c.Redirect(http.StatusMovedPermanently, "/web")
	})
	router.StaticFS("/web", http.Dir("./public"))

	// route describes one endpoint; the tables below are registered in order.
	type route struct {
		method  string
		path    string
		handler gin.HandlerFunc
	}

	// Public API: reachable without a ledger.
	publicRoutes := []route{
		{http.MethodGet, "/api/version", service.QueryVersion},
		{http.MethodPost, "/api/check", service.CheckBeancount},
		{http.MethodGet, "/api/config", service.QueryServerConfig},
		{http.MethodPost, "/api/config", service.UpdateServerConfig},
		{http.MethodGet, "/api/ledger", service.QueryLedgerList},
		{http.MethodPost, "/api/ledger", service.OpenOrCreateLedger},
	}
	for _, r := range publicRoutes {
		router.Handle(r.method, r.path, r.handler)
	}

	// Ledger-scoped API: every route in this group passes AuthorizedHandler.
	authorized := router.Group("/api/auth/", AuthorizedHandler())
	authorizedRoutes := []route{
		{http.MethodGet, "/account/valid", service.QueryValidAccount},
		{http.MethodGet, "/account/all", service.QueryAllAccount},
		{http.MethodGet, "/account/type", service.QueryAccountType},
		{http.MethodPost, "/account", service.AddAccount},
		{http.MethodPost, "/account/type", service.AddAccountType},
		{http.MethodPost, "/account/close", service.CloseAccount},
		{http.MethodPost, "/account/icon", service.ChangeAccountIcon},
		{http.MethodPost, "/account/balance", service.BalanceAccount},
		{http.MethodPost, "/account/refresh", service.RefreshAccountCache},
		{http.MethodPost, "/commodity/price", service.SyncCommodityPrice},
		{http.MethodGet, "/commodity/currencies", service.QueryAllCurrencies},
		{http.MethodGet, "/stats/months", service.MonthsList},
		{http.MethodGet, "/stats/total", service.StatsTotal},
		{http.MethodGet, "/stats/payee", service.StatsPayee},
		{http.MethodGet, "/stats/account/percent", service.StatsAccountPercent},
		{http.MethodGet, "/stats/account/trend", service.StatsAccountTrend},
		{http.MethodGet, "/stats/account/balance", service.StatsAccountBalance},
		{http.MethodGet, "/stats/account/flow", service.StatsAccountSankey},
		{http.MethodGet, "/stats/month/total", service.StatsMonthTotal},
		{http.MethodGet, "/stats/month/calendar", service.StatsMonthCalendar},
		{http.MethodGet, "/stats/commodity/price", service.StatsCommodityPrice},
		{http.MethodGet, "/transaction/detail", service.QueryTransactionDetailById},
		{http.MethodGet, "/transaction/raw", service.QueryTransactionRawTextById},
		{http.MethodGet, "/transaction", service.QueryTransactions},
		{http.MethodPost, "/transaction", service.AddTransactions},
		{http.MethodPost, "/transaction/raw", service.UpdateTransactionRawTextById},
		{http.MethodDelete, "/transaction", service.DeleteTransactionById},
		{http.MethodPost, "/transaction/batch", service.AddBatchTransactions},
		{http.MethodGet, "/transaction/payee", service.QueryTransactionPayees},
		{http.MethodGet, "/transaction/template", service.QueryTransactionTemplates},
		{http.MethodPost, "/transaction/template", service.AddTransactionTemplate},
		{http.MethodDelete, "/transaction/template", service.DeleteTransactionTemplate},
		{http.MethodGet, "/event/all", service.GetAllEvents},
		{http.MethodPost, "/event", service.AddEvent},
		{http.MethodDelete, "/event", service.DeleteEvent},
		{http.MethodGet, "/tags", service.QueryTags},
		{http.MethodGet, "/file/dir", service.QueryLedgerSourceFileDir},
		{http.MethodGet, "/file/content", service.QueryLedgerSourceFileContent},
		{http.MethodPost, "/file", service.UpdateLedgerSourceFileContent},
		{http.MethodPost, "/import/alipay", service.ImportAliPayCSV},
		{http.MethodPost, "/import/wx", service.ImportWxPayCSV},
		{http.MethodPost, "/import/icbc", service.ImportICBCCSV},
		{http.MethodPost, "/import/abc", service.ImportABCCSV},
		{http.MethodGet, "/ledger/check", service.CheckLedger},
		{http.MethodDelete, "/ledger", service.DeleteLedger},
	}
	for _, r := range authorizedRoutes {
		authorized.Handle(r.method, r.path, r.handler)
	}
}
// initVenvExecutor records venvDir as the Python virtual-environment path
// (in the package-level venvPath and via script.SetVenvPath). When
// venv.CheckVenvExists accepts the directory, it builds a VenvExecutor for it,
// registers that executor with the script package, and checks whether the
// bean-query command can be located.
//
// Problems are only reported (logged and printed as warnings), never returned:
// a missing virtual environment leaves venvExecutor nil.
func initVenvExecutor(venvDir string) {
	venvPath = venvDir
	script.SetVenvPath(venvDir)

	if !venv.CheckVenvExists(venvPath) {
		script.LogSystemError("虚拟环境不存在,请先运行 setup script: " + venvPath)
		fmt.Println("警告: 虚拟环境不存在,某些功能可能无法正常工作")
		fmt.Println("请运行: ./start_dev.sh 或手动创建虚拟环境")
		return
	}

	venvExecutor = venv.NewVenvExecutor(venvPath)
	script.SetVenvExecutor(venvExecutor)

	// Probe for bean-query so its availability is reported at startup.
	if _, lookupErr := venvExecutor.GetCommandPath("bean-query"); lookupErr != nil {
		script.LogSystemError("bean-query 不可用: " + lookupErr.Error())
		fmt.Println("警告: bean-query 命令不可用,价格查询功能将受限")
		return
	}
	script.LogSystemInfo("虚拟环境初始化成功: bean-query 可用, 路径: " + venvPath)
	fmt.Println("虚拟环境初始化成功: " + venvPath)
}
// main is the beancount-gs entry point. It parses the command-line flags,
// prepares the Python virtual environment, and loads the server configuration.
// When a data path is configured, it also initializes the ledger directory and
// caches. It then serves the HTTP API with gin.
//
// Flags:
//
//	-secret  server secret passed to script.GenerateServerSecret
//	-p       listen port (default 10000)
//	-debug   force debug mode on, overriding the configuration file
//	-venv    virtual-environment directory (default .env_beancount-v3)
//
// Any fatal startup error is logged and the process exits with status 1.
func main() {
	var secret string
	var port int
	var debugFlag bool
	var venvDir string
	flag.StringVar(&secret, "secret", "", "服务器密钥")
	flag.IntVar(&port, "p", 10000, "端口号")
	flag.BoolVar(&debugFlag, "debug", false, "调试模式")
	flag.StringVar(&venvDir, "venv", ".env_beancount-v3", "虚拟环境目录名称,默认值为 .env_beancount-v3")
	flag.Parse()

	// Set up the virtual-environment executor; problems there are only warnings.
	initVenvExecutor(venvDir)

	// Read the configuration file.
	if err := script.LoadServerConfig(); err != nil {
		script.LogSystemError("Failed to load server config, " + err.Error())
		os.Exit(1)
	}

	// The -debug flag overrides the debug setting from the configuration file.
	if debugFlag {
		if err := script.SetDebugMode(true); err != nil {
			fmt.Println("Warning: Failed to set debug mode:", err)
		}
	}
	if script.IsDebugMode() {
		fmt.Println("调试模式已启用")
	} else {
		fmt.Println("调试模式未启用")
	}

	// An empty DataPath means the configuration has not been initialized yet,
	// so the ledger directory and caches are skipped in that case.
	if script.GetServerConfig().DataPath != "" {
		if err := InitServerFiles(); err != nil {
			script.LogSystemError("Failed to init server files, " + err.Error())
			os.Exit(1)
		}
		if err := LoadServerCache(); err != nil {
			script.LogSystemError("Failed to load server cache, " + err.Error())
			os.Exit(1)
		}
	}

	// gin log output goes to stdout and, when possible, to logs/gin.log.
	// The file error must not be ignored: a nil *os.File inside io.MultiWriter
	// fails every Write, and MultiWriter stops at the first failing writer, so
	// nothing would reach stdout either.
	gin.DisableConsoleColor()
	var ginOut io.Writer = os.Stdout
	if mkErr := os.MkdirAll("logs", 0755); mkErr != nil {
		fmt.Println("Warning: Failed to create logs directory:", mkErr)
	} else if logFile, createErr := os.Create("logs/gin.log"); createErr != nil {
		fmt.Println("Warning: Failed to create gin log file:", createErr)
	} else {
		// logFile stays open for the lifetime of the process.
		ginOut = io.MultiWriter(logFile, os.Stdout)
	}
	gin.DefaultWriter = ginOut

	router := gin.Default()
	RegisterRouter(router)

	portStr := fmt.Sprintf(":%d", port)
	url := "http://localhost" + portStr
	startLog := "beancount-gs start at " + url
	if ip := script.GetIpAddress(); ip != "" {
		startLog += " or http://" + ip + portStr
	}
	script.LogSystemInfo(startLog)

	script.OpenBrowser(url)
	script.LogSystemInfo("Secret token is " + script.GenerateServerSecret(secret))

	// router.Run blocks while serving and only returns on failure.
	if err := router.Run(portStr); err != nil {
		script.LogSystemError("Failed to start server, " + err.Error())
		os.Exit(1)
	}
}