diff --git a/coding.go b/coding.go index e202f75..e50639a 100644 --- a/coding.go +++ b/coding.go @@ -2,44 +2,53 @@ package format import ( "fmt" - "sync" - blocks "github.com/ipfs/go-block-format" ) // DecodeBlockFunc functions decode blocks into nodes. type DecodeBlockFunc func(block blocks.Block) (Node, error) -type BlockDecoder interface { - Register(codec uint64, decoder DecodeBlockFunc) - Decode(blocks.Block) (Node, error) -} -type safeBlockDecoder struct { - // Can be replaced with an RCU if necessary. - lock sync.RWMutex +// Registry is a structure for storing mappings of multicodec IPLD codec numbers to DecodeBlockFunc functions. +// +// Registry includes no mutexing. If using Registry in a concurrent context, you must handle synchronization yourself. +// (Typically, it is recommended to do initialization earlier in a program, before fanning out goroutines; +// this avoids the need for mutexing overhead.) +// +// Multicodec indicator numbers are specified in +// https://github.com/multiformats/multicodec/blob/master/table.csv . +// You should not use indicator numbers which are not specified in that table +// (however, there is nothing in this implementation that will attempt to stop you, either). +type Registry struct { decoders map[uint64]DecodeBlockFunc } +func (r *Registry) ensureInit() { + if r.decoders != nil { + return + } + r.decoders = make(map[uint64]DecodeBlockFunc) +} + // Register registers decoder for all blocks with the passed codec. // // This will silently replace any existing registered block decoders. -func (d *safeBlockDecoder) Register(codec uint64, decoder DecodeBlockFunc) { - d.lock.Lock() - defer d.lock.Unlock() - d.decoders[codec] = decoder +func (r *Registry) Register(codec uint64, decoder DecodeBlockFunc) { + r.ensureInit() + if decoder == nil { + panic("not sensible to attempt to register a nil function") + } + r.decoders[codec] = decoder } -func (d *safeBlockDecoder) Decode(block blocks.Block) (Node, error) { +func (r *Registry) Decode(block blocks.Block) (Node, error) { // Short-circuit by cast if we already have a Node. if node, ok := block.(Node); ok { return node, nil } ty := block.Cid().Type() - - d.lock.RLock() - decoder, ok := d.decoders[ty] - d.lock.RUnlock() + r.ensureInit() + decoder, ok := r.decoders[ty] if ok { return decoder(block) @@ -49,14 +58,13 @@ func (d *safeBlockDecoder) Decode(block blocks.Block) (Node, error) { } } -var DefaultBlockDecoder BlockDecoder = &safeBlockDecoder{decoders: make(map[uint64]DecodeBlockFunc)} - -// Decode decodes the given block using the default BlockDecoder. -func Decode(block blocks.Block) (Node, error) { - return DefaultBlockDecoder.Decode(block) -} +// Decode decodes the given block using passed DecodeBlockFunc. +// Note: this is just a helper function, consider using the DecodeBlockFunc itself rather than this helper +func Decode(block blocks.Block, decoder DecodeBlockFunc) (Node, error) { + // Short-circuit by cast if we already have a Node. + if node, ok := block.(Node); ok { + return node, nil + } -// Register registers block decoders with the default BlockDecoder. -func Register(codec uint64, decoder DecodeBlockFunc) { - DefaultBlockDecoder.Register(codec, decoder) + return decoder(block) } diff --git a/coding_test.go b/coding_test.go index dad8498..ce80fa9 100644 --- a/coding_test.go +++ b/coding_test.go @@ -9,17 +9,53 @@ import ( mh "github.com/multiformats/go-multihash" ) -func init() { - Register(cid.Raw, func(b blocks.Block) (Node, error) { +func TestDecode(t *testing.T) { + decoder := func(b blocks.Block) (Node, error) { node := &EmptyNode{} if b.RawData() != nil || !b.Cid().Equals(node.Cid()) { return nil, errors.New("can only decode empty blocks") } return node, nil - }) + } + + id, err := cid.Prefix{ + Version: 1, + Codec: cid.Raw, + MhType: mh.ID, + MhLength: 0, + }.Sum(nil) + + if err != nil { + t.Fatalf("failed to create cid: %s", err) + } + + block, err := blocks.NewBlockWithCid(nil, id) + if err != nil { + t.Fatalf("failed to create empty block: %s", err) + } + node, err := Decode(block, decoder) + if err != nil { + t.Fatalf("failed to decode empty node: %s", err) + } + if !node.Cid().Equals(id) { + t.Fatalf("empty node doesn't have the right cid") + } + + if _, ok := node.(*EmptyNode); !ok { + t.Fatalf("empty node doesn't have the right type") + } + } -func TestDecode(t *testing.T) { +func TestRegistryDecode(t *testing.T) { + decoder := func(b blocks.Block) (Node, error) { + node := &EmptyNode{} + if b.RawData() != nil || !b.Cid().Equals(node.Cid()) { + return nil, errors.New("can only decode empty blocks") + } + return node, nil + } + id, err := cid.Prefix{ Version: 1, Codec: cid.Raw, @@ -35,10 +71,18 @@ func TestDecode(t *testing.T) { if err != nil { t.Fatalf("failed to create empty block: %s", err) } - node, err := Decode(block) + + reg := Registry{} + _, err = reg.Decode(block) + if err == nil || err.Error() != "unrecognized object type: 85" { + t.Fatalf("expected error, got %v", err) + } + reg.Register(cid.Raw, decoder) + node, err := reg.Decode(block) if err != nil { t.Fatalf("failed to decode empty node: %s", err) } + if !node.Cid().Equals(id) { t.Fatalf("empty node doesn't have the right cid") }