137 lines
3.2 KiB
Go
137 lines
3.2 KiB
Go
package config
|
||
|
||
import (
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"sync"
|
||
|
||
"gopkg.in/yaml.v3"
|
||
)
|
||
|
||
var (
|
||
config *Config
|
||
once sync.Once
|
||
)
|
||
|
||
func Init(args []string) {
|
||
_initConfig(args)
|
||
}
|
||
|
||
// Config is the application configuration decoded from the YAML config file,
// plus a few values derived after loading.
type Config struct {
	// Server holds the settings found under the top-level `server` key.
	Server struct {
		Host string `yaml:"host"`
		Port int    `yaml:"port"`
		Env  string `yaml:"env"`
	} `yaml:"server"`

	// Internally derived settings: these are computed by the loader, never
	// read from (or written to) YAML.

	// LogFile is the path of the log file, computed from the location of the
	// config file.
	LogFile string `yaml:"-"`
}
|
||
|
||
func GetConfig() *Config {
|
||
return _initConfig(nil)
|
||
}
|
||
|
||
func _initConfig(args []string) *Config {
|
||
if config == nil {
|
||
once.Do(func() {
|
||
c, err := loadConfig(args)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
config = c
|
||
})
|
||
}
|
||
return config
|
||
}
|
||
|
||
// getConfigFile resolves the path of the YAML configuration file.
//
// Lookup order:
//  1. confFile, if it is non-empty and exists;
//  2. <workdir>/output/conf/conf.yml;
//  3. <workdir>/conf/conf.yml;
//  4. <dir>/conf/conf.yml for each ancestor <dir> of the working directory
//     whose name is exactly "starter", outermost first (covers running from a
//     sub-package directly in an IDE such as GoLand);
//  5. a path derived from the executable: <exedir>/../conf/conf.yml when the
//     binary sits in a "bin" directory (deployed layout output/bin/), else
//     <exedir>/conf/conf.yml.
//
// Steps 1-4 only return paths that exist; step 5 is returned unchecked, so a
// missing file surfaces as a read error in the caller. It panics if the
// executable path cannot be determined.
func getConfigFile(confFile string) string {
	exists := func(p string) bool {
		_, err := os.Stat(p)
		return err == nil
	}

	// NOTE(review): a non-existent explicit confFile silently falls back to
	// the default search below; consider failing instead so typos are not
	// masked by another config.
	if confFile != "" && exists(confFile) {
		return confFile
	}

	// Prefer the working directory (e.g. running directly from GoLand).
	if workDir, err := os.Getwd(); err == nil {
		for _, candidate := range []string{
			filepath.Join(workDir, "output", "conf", "conf.yml"),
			filepath.Join(workDir, "conf", "conf.yml"),
		} {
			if exists(candidate) {
				return candidate
			}
		}

		// Collect workDir and all of its ancestors, innermost first.
		var dirs []string
		for dir := filepath.Clean(workDir); ; dir = filepath.Dir(dir) {
			dirs = append(dirs, dir)
			if filepath.Dir(dir) == dir {
				break
			}
		}
		// Match whole path components only, so look-alikes such as
		// "starter-kit" or "mystarter" are not mistaken for the project root.
		for i := len(dirs) - 1; i >= 0; i-- {
			if filepath.Base(dirs[i]) != "starter" {
				continue
			}
			if candidate := filepath.Join(dirs[i], "conf", "conf.yml"); exists(candidate) {
				return candidate
			}
		}
	}

	// Derive from the executable location (deployment environment).
	exePath, err := os.Executable()
	if err != nil {
		panic(fmt.Sprintf("获取程序路径失败: %v", err))
	}
	exeDir := filepath.Dir(exePath)

	if filepath.Base(exeDir) == "bin" {
		// Deployed layout: output/bin/<exe> -> output/conf/conf.yml
		return filepath.Join(exeDir, "..", "conf", "conf.yml")
	}

	// Fallback: conf/conf.yml next to the executable.
	return filepath.Join(exeDir, "conf", "conf.yml")
}
|
||
|
||
// 加载配置文件
|
||
func loadConfig(args []string) (*Config, error) {
|
||
confFile := getValueFromArgs(args, "config")
|
||
configFile := getConfigFile(confFile)
|
||
fmt.Println("configFile:", configFile)
|
||
// 读取配置文件内容
|
||
data, err := os.ReadFile(configFile)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取配置文件失败: %v", err)
|
||
}
|
||
|
||
// 解析 YAML 到结构体
|
||
var c Config
|
||
if err := yaml.Unmarshal(data, &c); err != nil {
|
||
return nil, fmt.Errorf("解析 YAML 失败: %v", err)
|
||
}
|
||
prefix := ""
|
||
if !strings.Contains(configFile, "/output/conf/") {
|
||
prefix = "/output"
|
||
}
|
||
pathArr := strings.Split(configFile, "/conf/")
|
||
if len(pathArr) == 0 {
|
||
panic("配置文件需要在`conf`目录下")
|
||
}
|
||
logDir := pathArr[0] + prefix + "/logs"
|
||
fmt.Println("logDir:", logDir)
|
||
// 自动创建日志目录(关键:避免目录不存在导致写入失败)
|
||
if err := os.MkdirAll(logDir, 0755); err != nil {
|
||
panic(fmt.Sprintf("创建日志目录失败: %v", err))
|
||
}
|
||
c.LogFile = filepath.Join(logDir, "starter.log")
|
||
return &c, nil
|
||
}
|
||
|
||
// getValueFromArgs returns the value of the first "key=value" entry in args,
// or "" when no entry starts with "key=". Only the leading "key=" is removed,
// so values that themselves contain "key=" are returned intact.
func getValueFromArgs(args []string, key string) string {
	prefix := key + "="
	for _, arg := range args {
		if strings.HasPrefix(arg, prefix) {
			return strings.TrimPrefix(arg, prefix)
		}
	}
	return ""
}
|