dkforest

A forum and chat platform (onion)
git clone https://git.dasho.dev/n0tr1v/dkforest.git
Log | Files | Refs | LICENSE

pubsub.go (3034B)


      1 package pubsub
      2 
      3 import (
      4 	"context"
      5 	"sync"
      6 	"time"
      7 
      8 	"github.com/pkg/errors"
      9 )
     10 
     11 // PubSub contains and manage the map of topics -> subscribers
     12 type PubSub[T any] struct {
     13 	sync.Mutex
     14 	m map[string][]*Sub[T]
     15 }
     16 
     17 func NewPubSub[T any]() *PubSub[T] {
     18 	ps := PubSub[T]{}
     19 	ps.m = make(map[string][]*Sub[T])
     20 	return &ps
     21 }
     22 
     23 func (p *PubSub[T]) getSubscribers(topic string) []*Sub[T] {
     24 	p.Lock()
     25 	defer p.Unlock()
     26 	return p.m[topic]
     27 }
     28 
     29 func (p *PubSub[T]) addSubscriber(s *Sub[T]) {
     30 	p.Lock()
     31 	for _, topic := range s.topics {
     32 		p.m[topic] = append(p.m[topic], s)
     33 	}
     34 	p.Unlock()
     35 }
     36 
     37 func (p *PubSub[T]) removeSubscriber(s *Sub[T]) {
     38 	p.Lock()
     39 	for _, topic := range s.topics {
     40 		for i, subscriber := range p.m[topic] {
     41 			if subscriber == s {
     42 				p.m[topic] = append(p.m[topic][:i], p.m[topic][i+1:]...)
     43 				break
     44 			}
     45 		}
     46 	}
     47 	p.Unlock()
     48 }
     49 
     50 // Subscribe is an alias for NewSub
     51 func (p *PubSub[T]) Subscribe(topics []string) *Sub[T] {
     52 	ctx, cancel := context.WithCancel(context.Background())
     53 	s := &Sub[T]{topics: topics, ch: make(chan Payload[T], 10), ctx: ctx, cancel: cancel, p: p}
     54 	p.addSubscriber(s)
     55 	return s
     56 }
     57 
     58 // Pub shortcut for publish which ignore the error
     59 func (p *PubSub[T]) Pub(topic string, msg T) {
     60 	for _, s := range p.getSubscribers(topic) {
     61 		s.publish(Payload[T]{topic, msg})
     62 	}
     63 }
     64 
     65 type Payload[T any] struct {
     66 	Topic string
     67 	Msg   T
     68 }
     69 
     70 // ErrTimeout error returned when timeout occurs
     71 var ErrTimeout = errors.New("timeout")
     72 
     73 // ErrCancelled error returned when context is cancelled
     74 var ErrCancelled = errors.New("cancelled")
     75 
     76 // Sub subscriber will receive messages published on a Topic in his ch
     77 type Sub[T any] struct {
     78 	topics []string        // Topics subscribed to
     79 	ch     chan Payload[T] // Receives messages in this channel
     80 	ctx    context.Context
     81 	cancel context.CancelFunc
     82 	p      *PubSub[T]
     83 }
     84 
     85 // ReceiveTimeout2 returns a message received on the channel or timeout
     86 func (s *Sub[T]) ReceiveTimeout2(timeout time.Duration, c1 <-chan struct{}) (topic string, msg T, err error) {
     87 	select {
     88 	case p := <-s.ch:
     89 		return p.Topic, p.Msg, nil
     90 	case <-time.After(timeout):
     91 		return topic, msg, ErrTimeout
     92 	case <-c1:
     93 		return topic, msg, ErrCancelled
     94 	case <-s.ctx.Done():
     95 		return topic, msg, ErrCancelled
     96 	}
     97 }
     98 
     99 // ReceiveTimeout returns a message received on the channel or timeout
    100 func (s *Sub[T]) ReceiveTimeout(timeout time.Duration) (topic string, msg T, err error) {
    101 	c1 := make(chan struct{})
    102 	return s.ReceiveTimeout2(timeout, c1)
    103 }
    104 
    105 // Receive returns a message
    106 func (s *Sub[T]) Receive() (topic string, msg T, err error) {
    107 	var res T
    108 	select {
    109 	case p := <-s.ch:
    110 		return p.Topic, p.Msg, nil
    111 	case <-s.ctx.Done():
    112 		return topic, res, ErrCancelled
    113 	}
    114 }
    115 
    116 // ReceiveCh returns a message
    117 func (s *Sub[T]) ReceiveCh() <-chan Payload[T] {
    118 	return s.ch
    119 }
    120 
    121 // Close will remove the subscriber from the Topic subscribers
    122 func (s *Sub[T]) Close() {
    123 	s.cancel()
    124 	s.p.removeSubscriber(s)
    125 }
    126 
    127 // publish a message to the subscriber channel
    128 func (s *Sub[T]) publish(p Payload[T]) {
    129 	select {
    130 	case s.ch <- p:
    131 	default:
    132 	}
    133 }