diff --git a/cmd/quota/quota.go b/cmd/quota/quota.go index e42a51b..dfea133 100644 --- a/cmd/quota/quota.go +++ b/cmd/quota/quota.go @@ -20,19 +20,22 @@ var QuotaCmd = &cobra.Command{ err := p.Login() if err != nil { logrus.Errorln("Login Failed:", err) + return // 加上这行,否则会继续执行 } q, err := p.GetQuota() if err != nil { logrus.Errorln("get cloud quota error:", err) return } + fmt.Println("Storage:") fmt.Printf("%-20s%-20s\n", "total", "used") - switch human { - case true: - fmt.Printf("%-20s%-20s\n", displayStorage(q.Limit), displayStorage(q.Usage)) - case false: - fmt.Printf("%-20s%-20s\n", q.Limit, q.Usage) + if human { + fmt.Printf("%-20s%-20s\n", displayStorage(q.Quota.Limit), displayStorage(q.Quota.Usage)) + } else { + fmt.Printf("%-20s%-20s\n", q.Quota.Limit, q.Quota.Usage) } + + displayCloudDownload(q.Quotas.CloudDownload) }, } @@ -42,15 +45,22 @@ func init() { func displayStorage(s string) string { size, _ := strconv.ParseFloat(s, 64) cnt := 0 - for size > 1024 { + for size >= 1024 { cnt += 1 if cnt > 5 { break } size /= 1024 } - // res := strconv.Itoa(int(size)) - res := strconv.FormatFloat(size, 'g', 2, 64) + + var res string + // 如果是整数则不显示小数点 + if size == float64(int64(size)) { + res = strconv.FormatFloat(size, 'f', 0, 64) + } else { + res = strconv.FormatFloat(size, 'f', 2, 64) + } + switch cnt { case 0: res += "B" @@ -67,3 +77,14 @@ func displayStorage(s string) string { } return res } + +func displayCloudDownload(cloudDownload pikpak.Quota) { + fmt.Printf("\ncloud download:\n") + fmt.Printf("%-20s%-20s%-20s\n", "total", "used", "remaining") + remaining, err := cloudDownload.Remaining() + if err != nil { + fmt.Printf("%-20s%-20s%-20s\n", cloudDownload.Limit, cloudDownload.Usage, "N/A") + return + } + fmt.Printf("%-20s%-20s%-20d\n", cloudDownload.Limit, cloudDownload.Usage, remaining) +} diff --git a/internal/pikpak/pikpak.go b/internal/pikpak/pikpak.go index 99c1269..ff160d7 100644 --- a/internal/pikpak/pikpak.go +++ b/internal/pikpak/pikpak.go @@ -37,11 +37,11 @@ func NewPikPak(account, password string) PikPak { }, } if conf.Config.UseProxy() { - url, err := url.Parse(conf.Config.Proxy) + proxyUrl, err := url.Parse(conf.Config.Proxy) if err != nil { logrus.Errorln("url parse proxy error", err) } - p := http.ProxyURL(url) + p := http.ProxyURL(proxyUrl) client.Transport = &http.Transport{ Proxy: p, } @@ -58,7 +58,8 @@ func NewPikPak(account, password string) PikPak { } } -func (p *PikPak) Login() error { +// login 执行完整登录流程 +func (p *PikPak) login() error { captchaToken, err := p.getCaptchaToken() if err != nil { return err @@ -147,3 +148,23 @@ func (p *PikPak) setHeader(req *http.Request) { req.Header.Set("User-Agent", userAgent) req.Header.Set("X-Device-Id", p.DeviceId) } + +// Login 优先复用本地 session,必要时才走完整登录 +func (p *PikPak) Login() error { + if err := p.loadSession(); err == nil { + if !p.isTokenExpired() { + logrus.Debugln("session valid, skip login") + return nil + } + logrus.Debugln("access_token expired, trying refresh_token") + if err = p.RefreshToken(); err == nil { + return p.saveSession() + } + logrus.Debugln("refresh failed, fallback to full login") + } + if err := p.login(); err != nil { + return err + } + // 执行了完整登录流程,保存session + return p.saveSession() +} diff --git a/internal/pikpak/quota.go b/internal/pikpak/quota.go index 1c566f8..5a26ceb 100644 --- a/internal/pikpak/quota.go +++ b/internal/pikpak/quota.go @@ -2,11 +2,12 @@ package pikpak import ( "net/http" + "strconv" jsoniter "github.com/json-iterator/go" ) -type quotaMessage struct { +type QuotaMessage struct { Kind string `json:"kind"` Quota Quota `json:"quota"` ExpiresAt string `json:"expires_at"` @@ -21,23 +22,37 @@ type Quota struct { PlayTimesUsage string `json:"play_times_usage"` } +// Remaining 剩余额度 +func (q Quota) Remaining() (int64, error) { + limit, err := strconv.ParseInt(q.Limit, 10, 64) + if err != nil { + return 0, err + } + usage, err := strconv.ParseInt(q.Usage, 10, 64) + if err != nil { + return 0, err + } + return limit - usage, nil +} + type Quotas struct { + CloudDownload Quota `json:"cloud_download"` } -// get cloud quota -func (p *PikPak) GetQuota() (Quota, error) { +// GetQuota get cloud quota +func (p *PikPak) GetQuota() (QuotaMessage, error) { req, err := http.NewRequest("GET", "https://api-drive.mypikpak.com/drive/v1/about", nil) if err != nil { - return Quota{}, err + return QuotaMessage{}, err } bs, err := p.sendRequest(req) if err != nil { - return Quota{}, err + return QuotaMessage{}, err } - var quotaMessage quotaMessage + var quotaMessage QuotaMessage err = jsoniter.Unmarshal(bs, "aMessage) if err != nil { - return Quota{}, err + return QuotaMessage{}, err } - return quotaMessage.Quota, nil + return quotaMessage, nil } diff --git a/internal/pikpak/refresh_token.go b/internal/pikpak/refresh_token.go index ced2acc..5bef698 100644 --- a/internal/pikpak/refresh_token.go +++ b/internal/pikpak/refresh_token.go @@ -35,7 +35,7 @@ func (p *PikPak) RefreshToken() error { // refresh token failed if error_code == 4126 { // 重新登录 - return p.Login() + return p.login() } return fmt.Errorf("refresh token error message: %d", gjson.GetBytes(bs, "error").Int()) } diff --git a/internal/pikpak/session.go b/internal/pikpak/session.go new file mode 100644 index 0000000..95e28af --- /dev/null +++ b/internal/pikpak/session.go @@ -0,0 +1,87 @@ +package pikpak + +import ( + "crypto/md5" + "encoding/hex" + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/sirupsen/logrus" +) + +// sessionData 是持久化到磁盘的数据结构 +type sessionData struct { + JwtToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + Sub string `json:"sub"` + // ExpiresAt 是 access_token 的过期 Unix 时间戳(秒) + ExpiresAt int64 `json:"expires_at"` +} + +// saveSession 将当前 token 信息持久化到本地文件 +func (p *PikPak) saveSession() error { + path, err := sessionFile(p.Account) + if err != nil { + return err + } + data := sessionData{ + JwtToken: p.JwtToken, + RefreshToken: p.refreshToken, + Sub: p.Sub, + // RefreshSecond 是服务端返回的 expires_in(秒),提前 5 分钟视为过期 + ExpiresAt: time.Now().Unix() + p.RefreshSecond - 300, + } + + bs, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("marshal session error: %w", err) + } + if err = os.WriteFile(path, bs, 0600); err != nil { + return fmt.Errorf("write session file error: %w", err) + } + logrus.Debugln("session saved to", path) + return nil +} + +// loadSession 从本地文件加载 token,并写回到 PikPak 实例 +// 如果文件不存在或账号不匹配,返回 error +func (p *PikPak) loadSession() error { + path, err := sessionFile(p.Account) + if err != nil { + return err + } + bs, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("read session file error: %w", err) + } + var data sessionData + if err = json.Unmarshal(bs, &data); err != nil { + return fmt.Errorf("unmarshal session error: %w", err) + } + + p.JwtToken = data.JwtToken + p.refreshToken = data.RefreshToken + p.Sub = data.Sub + p.RefreshSecond = data.ExpiresAt - time.Now().Unix() + logrus.Debugln("session loaded from", path) + return nil +} + +// isTokenExpired 判断 access_token 是否已过期(或即将过期) +// RefreshSecond 在 loadSession 后表示距过期的剩余秒数 +func (p *PikPak) isTokenExpired() bool { + return p.RefreshSecond <= 0 +} + +func sessionFile(account string) (string, error) { + configDir, err := os.UserConfigDir() + if err != nil { + return "", fmt.Errorf("get config dir error: %w", err) + } + hash := md5.Sum([]byte(account)) + filename := fmt.Sprintf("session_%s.json", hex.EncodeToString(hash[:])) + return filepath.Join(configDir, "pikpakcli", filename), nil +}