Skip to content

Commit

Permalink
refs #34: transaction wrapperをリファクタリング
Browse files Browse the repository at this point in the history
  • Loading branch information
CityBear3 committed Dec 8, 2023
1 parent d1b4a2f commit 304cf4d
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 16 deletions.
6 changes: 3 additions & 3 deletions internal/adaptor/gateway/repository/mysql/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (i *ArchiveRepository) Save(
archive entity.Archive,
) error {
var exec boil.ContextExecutor
exec, ok := ctx.Value("tx").(*sql.Tx)
exec, ok := getTxFromCtx(ctx)
if !ok {
exec = i.db
}
Expand Down Expand Up @@ -59,7 +59,7 @@ func (i *ArchiveRepository) GetArchive(
archiveID primitive.ID,
) (entity.Archive, error) {
var exec boil.ContextExecutor
exec, ok := ctx.Value("tx").(*sql.Tx)
exec, ok := getTxFromCtx(ctx)
if !ok {
exec = i.db
}
Expand Down Expand Up @@ -105,7 +105,7 @@ func (i *ArchiveRepository) GetArchiveByArchiveEventID(
archiveEventID primitive.ID,
) (entity.Archive, error) {
var exec boil.ContextExecutor
exec, ok := ctx.Value("tx").(*sql.Tx)
exec, ok := getTxFromCtx(ctx)
if !ok {
exec = i.db
}
Expand Down
2 changes: 1 addition & 1 deletion internal/adaptor/gateway/repository/mysql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func NewClientRepository(db boil.ContextExecutor) *ClientRepository {

func (i *ClientRepository) GetClient(ctx context.Context, clientID primitive.ID) (entity.Client, error) {
var exec boil.ContextExecutor
exec, ok := ctx.Value("tx").(*sql.Tx)
exec, ok := getTxFromCtx(ctx)
if !ok {
exec = i.db
}
Expand Down
2 changes: 1 addition & 1 deletion internal/adaptor/gateway/repository/mysql/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func NewDeviceRepository(db boil.ContextExecutor) *DeviceRepository {

func (d *DeviceRepository) GetDevice(ctx context.Context, deviceID primitive.ID) (entity.Device, error) {
var exec boil.ContextExecutor
exec, ok := ctx.Value("tx").(*sql.Tx)
exec, ok := getTxFromCtx(ctx)
if !ok {
exec = d.db
}
Expand Down
4 changes: 2 additions & 2 deletions internal/adaptor/gateway/repository/mysql/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func NewEventRepository(db boil.ContextExecutor) *EventRepository {

func (r *EventRepository) SaveArchiveEvent(ctx context.Context, archiveEvent entity.ArchiveEvent) error {
var exec boil.ContextExecutor
exec, ok := ctx.Value("tx").(*sql.Tx)
exec, ok := getTxFromCtx(ctx)
if !ok {
exec = r.db
}
Expand All @@ -45,7 +45,7 @@ func (r *EventRepository) SaveArchiveEvent(ctx context.Context, archiveEvent ent

func (r *EventRepository) GetArchiveEvent(ctx context.Context, archiveEventID primitive.ID) (entity.ArchiveEvent, error) {
var exec boil.ContextExecutor
exec, ok := ctx.Value("tx").(*sql.Tx)
exec, ok := getTxFromCtx(ctx)
if !ok {
exec = r.db
}
Expand Down
19 changes: 13 additions & 6 deletions internal/adaptor/gateway/repository/mysql/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,30 @@ func NewTxManger(db *sql.DB) *TxManager {
}
}

func (t *TxManager) DoInTx(ctx context.Context, operation usecase.Operation) error {
func (t *TxManager) DoInTx(ctx context.Context, operation usecase.Operation) (context.Context, error) {
tx, err := t.db.BeginTx(ctx, nil)
if err != nil {
return err
return ctx, err
}

ctx = context.WithValue(ctx, "tx", tx)

if err := operation(ctx); err != nil {
if err := tx.Rollback(); err != nil {
return err
return ctx, err
}
return err
return ctx, err
}

if err = tx.Commit(); err != nil {
return err
return ctx, err
}
return nil

ctx = context.WithValue(ctx, "tx", nil)
return ctx, nil
}

func getTxFromCtx(ctx context.Context) (*sql.Tx, bool) {
tx, ok := ctx.Value("tx").(*sql.Tx)
return tx, ok
}
2 changes: 1 addition & 1 deletion internal/usecase/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (i *ArchiveInteractor) CreateArchive(
request CreateArchiveInput,
device entity.Device,
) error {
if err := i.txManager.DoInTx(ctx, func(ctx2 context.Context) error {
if _, err := i.txManager.DoInTx(ctx, func(ctx2 context.Context) error {
archiveID := primitive.NewID()

event, err := i.eventRepository.GetArchiveEvent(ctx2, request.ArchiveEventID)
Expand Down
2 changes: 1 addition & 1 deletion internal/usecase/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func NewEventInteractor(

func (i EventInteractor) PublishArchiveEvent(ctx context.Context, client entity.Client) (primitive.ID, error) {
archiveEvent := entity.NewArchiveEvent(primitive.NewID(), client.Devices[0].ID, client.ID)
if err := i.txManager.DoInTx(ctx, func(ctx2 context.Context) error {
if _, err := i.txManager.DoInTx(ctx, func(ctx2 context.Context) error {
if err := i.eventRepository.SaveArchiveEvent(ctx2, archiveEvent); err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion internal/usecase/transaction.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:generate mockgen -source=$GOFILE -package=mock_usecase -destination=./mock/$GOFILE
package usecase

import (
Expand All @@ -7,5 +8,5 @@ import (
type Operation func(ctx context.Context) error

type ITxManager interface {
DoInTx(ctx context.Context, operation Operation) error
DoInTx(ctx context.Context, operation Operation) (context.Context, error)
}

0 comments on commit 304cf4d

Please sign in to comment.