Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion yatgbot/messagequeue/heap.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (j MessageJob) Execute(

var result tg.UpdatesBox

err := dispatcher.Client.Invoke(ctx, j.Request, &result)
err := dispatcher.client.Invoke(ctx, j.Request, &result)
if err != nil {
yaErr = yaerrors.FromError(
http.StatusInternalServerError,
Expand Down
172 changes: 149 additions & 23 deletions yatgbot/messagequeue/messagequeue.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,30 @@ package messagequeue

import (
"context"
"fmt"
"math/rand"
"net/http"
"sync"
"time"

"github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors"
"github.com/YaCodeDev/GoYaCodeDevUtils/yalogger"
"github.com/YaCodeDev/GoYaCodeDevUtils/yatgclient"
"github.com/YaCodeDev/GoYaCodeDevUtils/yatgmessageencoding"
"github.com/YaCodeDev/GoYaCodeDevUtils/yatgstorage"
"github.com/gotd/td/bin"
"github.com/gotd/td/tg"
)

// Dispatcher handles message sending with priority and concurrency control.
type Dispatcher struct {
Client *yatgclient.Client
client *yatgclient.Client
storage yatgstorage.IStorage
messageQueueChannel chan MessageJob
heap messageHeap
cond sync.Cond
log yalogger.Logger
parseMode yatgmessageencoding.MessageEncoding
log yalogger.Logger
}

// NewDispatcher creates a new Dispatcher with the given number of workers.
Expand All @@ -30,17 +35,19 @@ type Dispatcher struct {
//
// Example usage:
//
// dispatcher := NewDispatcher(ctx, 5, log)
// dispatcher := NewDispatcher(ctx, client, storage, workerCount, parseMode, log)
func NewDispatcher(
ctx context.Context,
client *yatgclient.Client,
storage yatgstorage.IStorage,
workerCount uint,
parseMode yatgmessageencoding.MessageEncoding,
log yalogger.Logger,
) *Dispatcher {
dispatcher := &Dispatcher{
parseMode: parseMode,
Client: client,
client: client,
storage: storage,
messageQueueChannel: make(chan MessageJob),
log: log,
heap: newMessageHeap(),
Expand Down Expand Up @@ -145,7 +152,7 @@ func (d *Dispatcher) AddEmptyJob(count uint) {
//
// Example usage:
//
// jobID, resultCh := dispatcher.AddForwardMessagesJob(messagesForwardMessagesRequest, priority)
// jobID, resultCh := dispatcher.AddForwardMessagesJob(ctx, messagesForwardMessagesRequest, priority)
//
// // Wait for the job result
// result := <-resultCh
Expand All @@ -154,9 +161,31 @@ func (d *Dispatcher) AddEmptyJob(count uint) {
// // Handle job error
// }
func (d *Dispatcher) AddForwardMessagesJob(
ctx context.Context,
req *tg.MessagesForwardMessagesRequest,
priority uint16,
) (uint64, <-chan JobResult) {
fromPeer, err := d.ensurePeerAccessHash(ctx, req.FromPeer)
if err != nil {
return 0, returnErrorJobResult(err)
}

req.FromPeer = fromPeer

toPeer, err := d.ensurePeerAccessHash(ctx, req.ToPeer)
if err != nil {
return 0, returnErrorJobResult(err)
}

req.ToPeer = toPeer

sendAsPeer, err := d.ensurePeerAccessHash(ctx, req.SendAs)
if err != nil {
return 0, returnErrorJobResult(err)
}

req.SendAs = sendAsPeer

req.RandomID = make([]int64, len(req.ID))
for i := range req.RandomID {
req.RandomID[i] = rand.Int63()
Expand All @@ -169,7 +198,7 @@ func (d *Dispatcher) AddForwardMessagesJob(
//
// Example usage:
//
// jobID, resultCh := dispatcher.AddSendMessageJob(messagesSendMessageRequest, priority)
// jobID, resultCh := dispatcher.AddSendMessageJob(ctx, messagesSendMessageRequest, priority)
//
// // Wait for the job result
// result := <-resultCh
Expand All @@ -178,16 +207,26 @@ func (d *Dispatcher) AddForwardMessagesJob(
// // Handle job error
// }
func (d *Dispatcher) AddSendMessageJob(
ctx context.Context,
req *tg.MessagesSendMessageRequest,
priority uint16,
) (uint64, <-chan JobResult) {
var (
message string
entities []tg.MessageEntityClass
)
peer, err := d.ensurePeerAccessHash(ctx, req.Peer)
if err != nil {
return 0, returnErrorJobResult(err)
}

req.Peer = peer

sendAsPeer, err := d.ensurePeerAccessHash(ctx, req.SendAs)
if err != nil {
return 0, returnErrorJobResult(err)
}

req.SendAs = sendAsPeer

if d.parseMode != nil {
message, entities = d.parseMode.Parse(req.Message)
message, entities := d.parseMode.Parse(req.Message)

req.Message = message
req.Entities = entities
Expand All @@ -204,7 +243,7 @@ func (d *Dispatcher) AddSendMessageJob(
//
// Example usage:
//
// jobID, resultCh := dispatcher.AddSendMultiMediaJob(messagesSendMediaRequest, priority)
// jobID, resultCh := dispatcher.AddSendMultiMediaJob(ctx, messagesSendMediaRequest, priority)
//
// // Wait for the job result
// result := <-resultCh
Expand All @@ -213,17 +252,27 @@ func (d *Dispatcher) AddSendMessageJob(
// // Handle job error
// }
func (d *Dispatcher) AddSendMultiMediaJob(
ctx context.Context,
req *tg.MessagesSendMultiMediaRequest,
priority uint16,
) (uint64, <-chan JobResult) {
var (
message string
entities []tg.MessageEntityClass
)
preparedPeer, err := d.ensurePeerAccessHash(ctx, req.Peer)
if err != nil {
return 0, returnErrorJobResult(err)
}

req.Peer = preparedPeer

sendAsPeer, err := d.ensurePeerAccessHash(ctx, req.SendAs)
if err != nil {
return 0, returnErrorJobResult(err)
}

req.SendAs = sendAsPeer

if d.parseMode != nil {
for i, media := range req.MultiMedia {
message, entities = d.parseMode.Parse(media.Message)
message, entities := d.parseMode.Parse(media.Message)

media.Message = message
media.Entities = entities
Expand All @@ -243,7 +292,7 @@ func (d *Dispatcher) AddSendMultiMediaJob(
//
// Example usage:
//
// jobID, resultCh := dispatcher.AddSendMediaJob(messagesSendMediaRequest, priority)
// jobID, resultCh := dispatcher.AddSendMediaJob(ctx, messagesSendMediaRequest, priority)
//
// // Wait for the job result
//
Expand All @@ -253,16 +302,26 @@ func (d *Dispatcher) AddSendMultiMediaJob(
// // Handle job error
// }
func (d *Dispatcher) AddSendMediaJob(
ctx context.Context,
req *tg.MessagesSendMediaRequest,
priority uint16,
) (uint64, <-chan JobResult) {
var (
message string
entities []tg.MessageEntityClass
)
preparedPeer, err := d.ensurePeerAccessHash(ctx, req.Peer)
if err != nil {
return 0, returnErrorJobResult(err)
}

req.Peer = preparedPeer

sendAsPeer, err := d.ensurePeerAccessHash(ctx, req.SendAs)
if err != nil {
return 0, returnErrorJobResult(err)
}

req.SendAs = sendAsPeer

if d.parseMode != nil {
message, entities = d.parseMode.Parse(req.Message)
message, entities := d.parseMode.Parse(req.Message)

req.Message = message
req.Entities = entities
Expand Down Expand Up @@ -319,3 +378,70 @@ func (d *Dispatcher) worker(ctx context.Context, id uint) {
}
}
}

// ensurePeerAccessHash checks if the given peer has an access hash.
// If not, it retrieves the access hash from storage and updates the peer.
func (d *Dispatcher) ensurePeerAccessHash(
ctx context.Context,
peer tg.InputPeerClass,
) (tg.InputPeerClass, yaerrors.Error) {
bot, err := d.client.Self(ctx)
if err != nil {
return nil, yaerrors.FromError(
http.StatusInternalServerError,
err,
"failed to get self for preparePeer",
)
}

switch p := peer.(type) {
case *tg.InputPeerUser:
if p.AccessHash == 0 || p.UserID != 0 {
accessHash, err := d.storage.GetUserAccessHash(ctx, bot.ID, p.UserID)
if err != nil {
return nil, err.Wrap("failed to get user access hash for ensurePeerAccessHash")
}

p.AccessHash = accessHash

return p, nil
}

return peer, nil
case *tg.InputPeerChannel:
if p.AccessHash == 0 || p.ChannelID != 0 {
accessHash, found, err := d.storage.GetChannelAccessHash(ctx, bot.ID, p.ChannelID)
if err != nil {
return nil, err.Wrap("failed to get channel access hash for ensurePeerAccessHash")
}

if !found {
return nil, yaerrors.FromString(
http.StatusNotFound,
fmt.Sprintf("access hash for channel %d not found", p.ChannelID),
)
}

p.AccessHash = accessHash

return p, nil
}

return peer, nil
default:
return peer, nil
}
}

// returnErrorJobResult creates a channel that immediately returns a JobResult
// with the given error and then closes the channel.
func returnErrorJobResult(err yaerrors.Error) <-chan JobResult {
ch := make(chan JobResult, 1)
ch <- JobResult{
Err: err,
}

close(ch)

return ch
}
1 change: 1 addition & 0 deletions yatgbot/yatgbot.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ func InitYaTgBot(
msgDispatcher := messagequeue.NewDispatcher(
ctx,
client,
stateStorage,
options.MessageQueueRatePerSecond,
options.ParseMode,
options.Log,
Expand Down