Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow configuring a custom marshaller #10

Merged
merged 2 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ type amqpChannel struct {

// connectionType defines the connectionType.
connectionType connectionType

// marshaller defines the marshalling method used to encode messages.
marshaller Marshaller
}

// newConsumerChannel instantiates a new consumerChannel and amqpChannel for method inheritance.
Expand All @@ -87,13 +90,15 @@ type amqpChannel struct {
// - consumer is the MessageConsumer that will hold consumption information.
// - maxRetry is the retry header for each message.
// - logger is the parent logger.
// - marshaller is the Marshaller used for encoding messages.
func newConsumerChannel(
ctx context.Context,
connection *amqp.Connection,
keepAlive bool,
retryDelay time.Duration,
consumer *MessageConsumer,
logger logger,
marshaller Marshaller,
) *amqpChannel {
channel := &amqpChannel{
ctx: ctx,
Expand All @@ -119,6 +124,7 @@ func newConsumerChannel(
connectionType: connectionTypeConsumer,
consumptionHealth: make(consumptionHealth),
consumer: consumer,
marshaller: marshaller,
}

// We open an initial channel.
Expand All @@ -141,6 +147,7 @@ func newConsumerChannel(
// - publishingCacheSize is the maximum cache size of failed publishing.
// - publishingCacheTTL defines the time to live for each failed publishing that was put in cache.
// - logger is the parent logger.
// - marshaller is the Marshaller used for encoding messages.
func newPublishingChannel(
ctx context.Context,
connection *amqp.Connection,
Expand All @@ -150,6 +157,7 @@ func newPublishingChannel(
publishingCacheSize uint64,
publishingCacheTTL time.Duration,
logger logger,
marshaller Marshaller,
) *amqpChannel {
channel := &amqpChannel{
ctx: ctx,
Expand All @@ -171,6 +179,7 @@ func newPublishingChannel(
connectionType: connectionTypePublisher,
publishingCache: newTTLMap[string, mqttPublishing](publishingCacheSize, publishingCacheTTL),
maxRetry: maxRetry,
marshaller: marshaller,
}

// We open an initial channel.
Expand Down Expand Up @@ -521,7 +530,7 @@ func (c *amqpChannel) retryDelivery(delivery *amqp.Delivery, alreadyAcknowledged

// We create a new publishing which is a copy of the old one but with a decremented xDeathCountHeader.
newPublishing := amqp.Publishing{
ContentType: "application/json",
ContentType: c.marshaller.ContentType(),
Body: delivery.Body,
Type: delivery.RoutingKey,
Priority: delivery.Priority,
Expand Down Expand Up @@ -554,7 +563,7 @@ func (c *amqpChannel) retryDelivery(delivery *amqp.Delivery, alreadyAcknowledged
// publish will publish a message with the given configuration.
func (c *amqpChannel) publish(exchange string, routingKey string, payload []byte, options *PublishingOptions) error {
publishing := &amqp.Publishing{
ContentType: "application/json",
ContentType: c.marshaller.ContentType(),
Body: payload,
Type: routingKey,
Priority: PriorityMedium.Uint8(),
Expand Down
5 changes: 5 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ func newClientFromOptions(options *ClientOptions) MQTTClient {
protocol = securedProtocol
}

if options.Marshaller == nil {
options.Marshaller = defaultMarshaller
}

dialURL := fmt.Sprintf("%s://%s:%s@%s:%d/%s", protocol, client.Username, client.Password, client.Host, client.Port, client.Vhost)

client.connectionManager = newConnectionManager(
Expand All @@ -163,6 +167,7 @@ func newClientFromOptions(options *ClientOptions) MQTTClient {
options.PublishingCacheSize,
options.PublishingCacheTTL,
client.logger,
options.Marshaller,
)

return client
Expand Down
11 changes: 11 additions & 0 deletions client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ type ClientOptions struct {

// Mode will specify whether logs are enabled or not.
Mode string

// Marshaller defines the content type used for messages and how they're marshalled (default: JSON).
Marshaller Marshaller
}

// DefaultClientOptions will return a ClientOptions with default values.
Expand All @@ -63,6 +66,7 @@ func DefaultClientOptions() *ClientOptions {
PublishingCacheTTL: defaultPublishingCacheTTL,
PublishingCacheSize: defaultPublishingCacheSize,
Mode: defaultMode,
Marshaller: defaultMarshaller,
}
}

Expand Down Expand Up @@ -195,3 +199,10 @@ func (c *ClientOptions) SetMode(mode string) *ClientOptions {

return c
}

// SetMarshaller will assign the Marshaller.
func (c *ClientOptions) SetMarshaller(marshaller Marshaller) *ClientOptions {
c.Marshaller = marshaller

return c
}
m3talux marked this conversation as resolved.
Show resolved Hide resolved
30 changes: 25 additions & 5 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ type amqpConnection struct {

// connectionType defines the connectionType.
connectionType connectionType

// marshaller defines the marshalling method used to encode messages.
marshaller Marshaller
}

// newConsumerConnection initializes a new consumer amqpConnection with given arguments.
Expand All @@ -57,8 +60,17 @@ type amqpConnection struct {
// - keepAlive will keep the connection alive if true.
// - retryDelay defines the delay between each re-connection, if the keepAlive flag is set to true.
// - logger is the parent logger.
func newConsumerConnection(ctx context.Context, uri, connectionName string, keepAlive bool, retryDelay time.Duration, logger logger) *amqpConnection {
return newConnection(ctx, uri, connectionName, keepAlive, retryDelay, logger, connectionTypeConsumer)
// - marshaller is the Marshaller used for encoding messages.
func newConsumerConnection(
ctx context.Context,
uri string,
connectionName string,
keepAlive bool,
retryDelay time.Duration,
logger logger,
marshaller Marshaller,
) *amqpConnection {
return newConnection(ctx, uri, connectionName, keepAlive, retryDelay, logger, connectionTypeConsumer, marshaller)
}

// newPublishingConnection initializes a new publisher amqpConnection with given arguments.
Expand All @@ -71,6 +83,7 @@ func newConsumerConnection(ctx context.Context, uri, connectionName string, keep
// - publishingCacheSize defines the maximum length of failed publishing cache.
// - publishingCacheTTL defines the time to live for failed publishing in cache.
// - logger is the parent logger.
// - marshaller is the Marshaller used for encoding messages.
func newPublishingConnection(
ctx context.Context,
uri string,
Expand All @@ -81,8 +94,9 @@ func newPublishingConnection(
publishingCacheSize uint64,
publishingCacheTTL time.Duration,
logger logger,
marshaller Marshaller,
) *amqpConnection {
conn := newConnection(ctx, uri, connectionName, keepAlive, retryDelay, logger, connectionTypePublisher)
conn := newConnection(ctx, uri, connectionName, keepAlive, retryDelay, logger, connectionTypePublisher, marshaller)

conn.maxRetry = maxRetry
conn.publishingCacheSize = publishingCacheSize
Expand All @@ -98,6 +112,7 @@ func newPublishingConnection(
// - keepAlive will keep the connection alive if true.
// - retryDelay defines the delay between each re-connection, if the keepAlive flag is set to true.
// - logger is the parent logger.
// - marshaller is the Marshaller used for encoding messages.
func newConnection(
ctx context.Context,
uri string,
Expand All @@ -106,6 +121,7 @@ func newConnection(
retryDelay time.Duration,
logger logger,
connectionType connectionType,
marshaller Marshaller,
) *amqpConnection {
conn := &amqpConnection{
ctx: ctx,
Expand All @@ -119,6 +135,7 @@ func newConnection(
"type": connectionType,
}),
connectionType: connectionType,
marshaller: marshaller,
}

conn.logger.Debug("Initializing new amqp connection", logField{Key: "uri", Value: conn.uriForLog()})
Expand Down Expand Up @@ -303,7 +320,7 @@ func (a *amqpConnection) registerConsumer(consumer MessageConsumer) error {
return err
}

channel := newConsumerChannel(a.ctx, a.connection, a.keepAlive, a.retryDelay, &consumer, a.logger)
channel := newConsumerChannel(a.ctx, a.connection, a.keepAlive, a.retryDelay, &consumer, a.logger, a.marshaller)

a.channels = append(a.channels, channel)

Expand All @@ -315,7 +332,10 @@ func (a *amqpConnection) registerConsumer(consumer MessageConsumer) error {
func (a *amqpConnection) publish(exchange, routingKey string, payload []byte, options *PublishingOptions) error {
publishingChannel := a.channels.publishingChannel()
if publishingChannel == nil {
publishingChannel = newPublishingChannel(a.ctx, a.connection, a.keepAlive, a.retryDelay, a.maxRetry, a.publishingCacheSize, a.publishingCacheTTL, a.logger)
publishingChannel = newPublishingChannel(
a.ctx, a.connection, a.keepAlive, a.retryDelay, a.maxRetry,
a.publishingCacheSize, a.publishingCacheTTL, a.logger, a.marshaller,
)

a.channels = append(a.channels, publishingChannel)
}
Expand Down
17 changes: 13 additions & 4 deletions connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package gorabbit

import (
"context"
"encoding/json"
"time"
)

Expand All @@ -12,6 +11,9 @@ type connectionManager struct {

// publisherConnection holds the independent publishing connection.
publisherConnection *amqpConnection

// marshaller holds the marshaller used to encode messages.
marshaller Marshaller
}

// newConnectionManager instantiates a new connectionManager with given arguments.
Expand All @@ -25,10 +27,17 @@ func newConnectionManager(
publishingCacheSize uint64,
publishingCacheTTL time.Duration,
logger logger,
marshaller Marshaller,
) *connectionManager {
c := &connectionManager{
consumerConnection: newConsumerConnection(ctx, uri, connectionName, keepAlive, retryDelay, logger),
publisherConnection: newPublishingConnection(ctx, uri, connectionName, keepAlive, retryDelay, maxRetry, publishingCacheSize, publishingCacheTTL, logger),
consumerConnection: newConsumerConnection(
ctx, uri, connectionName, keepAlive, retryDelay, logger, marshaller,
),
publisherConnection: newPublishingConnection(
ctx, uri, connectionName, keepAlive, retryDelay, maxRetry,
publishingCacheSize, publishingCacheTTL, logger, marshaller,
),
marshaller: marshaller,
}

return c
Expand Down Expand Up @@ -75,7 +84,7 @@ func (c *connectionManager) publish(exchange, routingKey string, payload interfa
return errPublisherConnectionNotInitialized
}

payloadBytes, err := json.Marshal(payload)
payloadBytes, err := c.marshaller.Marshal(payload)
if err != nil {
return err
}
Expand Down
2 changes: 2 additions & 0 deletions constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ const (
defaultMode = Release
)

var defaultMarshaller = NewJSONMarshaller()

// Default values for the amqp Config.
const (
defaultHeartbeat = 10 * time.Second
Expand Down
12 changes: 10 additions & 2 deletions manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ type mqttManager struct {

// channel holds the single channel from the connection.
channel *amqp.Channel

// marshaller holds the marshaller used to encode messages.
marshaller Marshaller
}

// NewManager will instantiate a new MQTTManager.
Expand Down Expand Up @@ -158,6 +161,11 @@ func newManagerFromOptions(options *ManagerOptions) (MQTTManager, error) {
protocol = securedProtocol
}

if options.Marshaller == nil {
options.Marshaller = defaultMarshaller
}
manager.marshaller = options.Marshaller

dialURL := fmt.Sprintf("%s://%s:%s@%s:%d/%s", protocol, manager.Username, manager.Password, manager.Host, manager.Port, manager.Vhost)

var err error
Expand Down Expand Up @@ -320,14 +328,14 @@ func (manager *mqttManager) PushMessageToExchange(exchange, routingKey string, p
}

// We convert the payload to a []byte.
payloadBytes, err := json.Marshal(payload)
payloadBytes, err := manager.marshaller.Marshal(payload)
if err != nil {
return err
}

// We build the amqp.Publishing object.
publishing := amqp.Publishing{
ContentType: "application/json",
ContentType: manager.marshaller.ContentType(),
Body: payloadBytes,
Type: routingKey,
Priority: PriorityMedium.Uint8(),
Expand Down
25 changes: 18 additions & 7 deletions manager_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,22 @@ type ManagerOptions struct {

// Mode will specify whether logs are enabled or not.
Mode string

// Marshaller defines the content type used for messages and how they're marshalled (default: JSON).
Marshaller Marshaller
}

// DefaultManagerOptions will return a ManagerOptions with default values.
func DefaultManagerOptions() *ManagerOptions {
return &ManagerOptions{
Host: defaultHost,
Port: defaultPort,
Username: defaultUsername,
Password: defaultPassword,
Vhost: defaultVhost,
UseTLS: defaultUseTLS,
Mode: defaultMode,
Host: defaultHost,
Port: defaultPort,
Username: defaultUsername,
Password: defaultPassword,
Vhost: defaultVhost,
UseTLS: defaultUseTLS,
Mode: defaultMode,
Marshaller: defaultMarshaller,
}
}

Expand Down Expand Up @@ -126,3 +130,10 @@ func (m *ManagerOptions) SetMode(mode string) *ManagerOptions {

return m
}

// SetMarshaller will assign the Marshaller.
func (m *ManagerOptions) SetMarshaller(marshaller Marshaller) *ManagerOptions {
m.Marshaller = marshaller

return m
}
m3talux marked this conversation as resolved.
Show resolved Hide resolved
Loading