diff --git a/yatgbot/messagequeue/heap.go b/yatgbot/messagequeue/heap.go index 7a2daa0..280d9b7 100644 --- a/yatgbot/messagequeue/heap.go +++ b/yatgbot/messagequeue/heap.go @@ -60,7 +60,7 @@ func (j MessageJob) Execute( var result tg.UpdatesBox - err := dispatcher.Client.Invoke(ctx, j.Request, &result) + err := dispatcher.client.Invoke(ctx, j.Request, &result) if err != nil { yaErr = yaerrors.FromError( http.StatusInternalServerError, diff --git a/yatgbot/messagequeue/messagequeue.go b/yatgbot/messagequeue/messagequeue.go index 7a10773..1f921df 100644 --- a/yatgbot/messagequeue/messagequeue.go +++ b/yatgbot/messagequeue/messagequeue.go @@ -2,25 +2,30 @@ package messagequeue import ( "context" + "fmt" "math/rand" + "net/http" "sync" "time" + "github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors" "github.com/YaCodeDev/GoYaCodeDevUtils/yalogger" "github.com/YaCodeDev/GoYaCodeDevUtils/yatgclient" "github.com/YaCodeDev/GoYaCodeDevUtils/yatgmessageencoding" + "github.com/YaCodeDev/GoYaCodeDevUtils/yatgstorage" "github.com/gotd/td/bin" "github.com/gotd/td/tg" ) // Dispatcher handles message sending with priority and concurrency control. type Dispatcher struct { - Client *yatgclient.Client + client *yatgclient.Client + storage yatgstorage.IStorage messageQueueChannel chan MessageJob heap messageHeap cond sync.Cond - log yalogger.Logger parseMode yatgmessageencoding.MessageEncoding + log yalogger.Logger } // NewDispatcher creates a new Dispatcher with the given number of workers. @@ -30,17 +35,19 @@ type Dispatcher struct { // // Example usage: // -// dispatcher := NewDispatcher(ctx, 5, log) +// dispatcher := NewDispatcher(ctx, client, storage, workerCount, parseMode, log) func NewDispatcher( ctx context.Context, client *yatgclient.Client, + storage yatgstorage.IStorage, workerCount uint, parseMode yatgmessageencoding.MessageEncoding, log yalogger.Logger, ) *Dispatcher { dispatcher := &Dispatcher{ parseMode: parseMode, - Client: client, + client: client, + storage: storage, messageQueueChannel: make(chan MessageJob), log: log, heap: newMessageHeap(), @@ -145,7 +152,7 @@ func (d *Dispatcher) AddEmptyJob(count uint) { // // Example usage: // -// jobID, resultCh := dispatcher.AddForwardMessagesJob(messagesForwardMessagesRequest, priority) +// jobID, resultCh := dispatcher.AddForwardMessagesJob(ctx, messagesForwardMessagesRequest, priority) // // // Wait for the job result // result := <-resultCh @@ -154,9 +161,31 @@ func (d *Dispatcher) AddEmptyJob(count uint) { // // Handle job error // } func (d *Dispatcher) AddForwardMessagesJob( + ctx context.Context, req *tg.MessagesForwardMessagesRequest, priority uint16, ) (uint64, <-chan JobResult) { + fromPeer, err := d.ensurePeerAccessHash(ctx, req.FromPeer) + if err != nil { + return 0, returnErrorJobResult(err) + } + + req.FromPeer = fromPeer + + toPeer, err := d.ensurePeerAccessHash(ctx, req.ToPeer) + if err != nil { + return 0, returnErrorJobResult(err) + } + + req.ToPeer = toPeer + + sendAsPeer, err := d.ensurePeerAccessHash(ctx, req.SendAs) + if err != nil { + return 0, returnErrorJobResult(err) + } + + req.SendAs = sendAsPeer + req.RandomID = make([]int64, len(req.ID)) for i := range req.RandomID { req.RandomID[i] = rand.Int63() @@ -169,7 +198,7 @@ func (d *Dispatcher) AddForwardMessagesJob( // // Example usage: // -// jobID, resultCh := dispatcher.AddSendMessageJob(messagesSendMessageRequest, priority) +// jobID, resultCh := dispatcher.AddSendMessageJob(ctx, messagesSendMessageRequest, priority) // // // Wait for the job result // result := <-resultCh @@ -178,16 +207,26 @@ func (d *Dispatcher) AddForwardMessagesJob( // // Handle job error // } func (d *Dispatcher) AddSendMessageJob( + ctx context.Context, req *tg.MessagesSendMessageRequest, priority uint16, ) (uint64, <-chan JobResult) { - var ( - message string - entities []tg.MessageEntityClass - ) + peer, err := d.ensurePeerAccessHash(ctx, req.Peer) + if err != nil { + return 0, returnErrorJobResult(err) + } + + req.Peer = peer + + sendAsPeer, err := d.ensurePeerAccessHash(ctx, req.SendAs) + if err != nil { + return 0, returnErrorJobResult(err) + } + + req.SendAs = sendAsPeer if d.parseMode != nil { - message, entities = d.parseMode.Parse(req.Message) + message, entities := d.parseMode.Parse(req.Message) req.Message = message req.Entities = entities @@ -204,7 +243,7 @@ func (d *Dispatcher) AddSendMessageJob( // // Example usage: // -// jobID, resultCh := dispatcher.AddSendMultiMediaJob(messagesSendMediaRequest, priority) +// jobID, resultCh := dispatcher.AddSendMultiMediaJob(ctx, messagesSendMediaRequest, priority) // // // Wait for the job result // result := <-resultCh @@ -213,17 +252,27 @@ func (d *Dispatcher) AddSendMessageJob( // // Handle job error // } func (d *Dispatcher) AddSendMultiMediaJob( + ctx context.Context, req *tg.MessagesSendMultiMediaRequest, priority uint16, ) (uint64, <-chan JobResult) { - var ( - message string - entities []tg.MessageEntityClass - ) + preparedPeer, err := d.ensurePeerAccessHash(ctx, req.Peer) + if err != nil { + return 0, returnErrorJobResult(err) + } + + req.Peer = preparedPeer + + sendAsPeer, err := d.ensurePeerAccessHash(ctx, req.SendAs) + if err != nil { + return 0, returnErrorJobResult(err) + } + + req.SendAs = sendAsPeer if d.parseMode != nil { for i, media := range req.MultiMedia { - message, entities = d.parseMode.Parse(media.Message) + message, entities := d.parseMode.Parse(media.Message) media.Message = message media.Entities = entities @@ -243,7 +292,7 @@ func (d *Dispatcher) AddSendMultiMediaJob( // // Example usage: // -// jobID, resultCh := dispatcher.AddSendMediaJob(messagesSendMediaRequest, priority) +// jobID, resultCh := dispatcher.AddSendMediaJob(ctx, messagesSendMediaRequest, priority) // // // Wait for the job result // @@ -253,16 +302,26 @@ func (d *Dispatcher) AddSendMultiMediaJob( // // Handle job error // } func (d *Dispatcher) AddSendMediaJob( + ctx context.Context, req *tg.MessagesSendMediaRequest, priority uint16, ) (uint64, <-chan JobResult) { - var ( - message string - entities []tg.MessageEntityClass - ) + preparedPeer, err := d.ensurePeerAccessHash(ctx, req.Peer) + if err != nil { + return 0, returnErrorJobResult(err) + } + + req.Peer = preparedPeer + + sendAsPeer, err := d.ensurePeerAccessHash(ctx, req.SendAs) + if err != nil { + return 0, returnErrorJobResult(err) + } + + req.SendAs = sendAsPeer if d.parseMode != nil { - message, entities = d.parseMode.Parse(req.Message) + message, entities := d.parseMode.Parse(req.Message) req.Message = message req.Entities = entities @@ -319,3 +378,70 @@ func (d *Dispatcher) worker(ctx context.Context, id uint) { } } } + +// ensurePeerAccessHash checks if the given peer has an access hash. +// If not, it retrieves the access hash from storage and updates the peer. +func (d *Dispatcher) ensurePeerAccessHash( + ctx context.Context, + peer tg.InputPeerClass, +) (tg.InputPeerClass, yaerrors.Error) { + bot, err := d.client.Self(ctx) + if err != nil { + return nil, yaerrors.FromError( + http.StatusInternalServerError, + err, + "failed to get self for preparePeer", + ) + } + + switch p := peer.(type) { + case *tg.InputPeerUser: + if p.AccessHash == 0 || p.UserID != 0 { + accessHash, err := d.storage.GetUserAccessHash(ctx, bot.ID, p.UserID) + if err != nil { + return nil, err.Wrap("failed to get user access hash for ensurePeerAccessHash") + } + + p.AccessHash = accessHash + + return p, nil + } + + return peer, nil + case *tg.InputPeerChannel: + if p.AccessHash == 0 || p.ChannelID != 0 { + accessHash, found, err := d.storage.GetChannelAccessHash(ctx, bot.ID, p.ChannelID) + if err != nil { + return nil, err.Wrap("failed to get channel access hash for ensurePeerAccessHash") + } + + if !found { + return nil, yaerrors.FromString( + http.StatusNotFound, + fmt.Sprintf("access hash for channel %d not found", p.ChannelID), + ) + } + + p.AccessHash = accessHash + + return p, nil + } + + return peer, nil + default: + return peer, nil + } +} + +// returnErrorJobResult creates a channel that immediately returns a JobResult +// with the given error and then closes the channel. +func returnErrorJobResult(err yaerrors.Error) <-chan JobResult { + ch := make(chan JobResult, 1) + ch <- JobResult{ + Err: err, + } + + close(ch) + + return ch +} diff --git a/yatgbot/yatgbot.go b/yatgbot/yatgbot.go index ad0199b..adc90cc 100644 --- a/yatgbot/yatgbot.go +++ b/yatgbot/yatgbot.go @@ -120,6 +120,7 @@ func InitYaTgBot( msgDispatcher := messagequeue.NewDispatcher( ctx, client, + stateStorage, options.MessageQueueRatePerSecond, options.ParseMode, options.Log,