Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
* Make parameters explicit in constructors; use in tests
* Add type annotations in functions
* Update constructors based on #257 and #263

Signed-off-by: Marcela Melara <marcela.melara@intel.com>
  • Loading branch information
marcelamelara committed Jul 12, 2023
1 parent 62c1e9c commit e5f2403
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 70 deletions.
4 changes: 2 additions & 2 deletions docs/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ make go_test

### Writing new Go tests

Please use the standard [Golang testing package] to write tests
Please use the standard [Go testing package] to write tests
for new predicates. For example tests, take a look at the `*_test.go`
files in the `go/` directory tree.

Expand Down Expand Up @@ -47,5 +47,5 @@ modules in the `tests/python/` directory tree.
At a minimum, we suggest testing JSON marshalling and unmarshalling
of the Python language bindings.

[Golang testing package]: https://pkg.go.dev/testing
[Go testing package]: https://pkg.go.dev/testing
[Python unittest package]: https://docs.python.org/3/library/unittest.html
27 changes: 21 additions & 6 deletions python/in_toto_attestation/v1/resource_descriptor.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
# Wrapper class for in-toto attestation ResourceDescriptor protos.

import in_toto_attestation.v1.resource_descriptor_pb2 as rdpb
from google.protobuf.struct_pb2 import Value
import google.protobuf.json_format as pb_json
import json

class ResourceDescriptor:
def __init__(self, rd=None):
def __init__(self, name: str='', uri: str='', digest: dict=None, content: bytes=bytes(), download_location: str='', media_type: str='', annotations: dict=None) -> None:
self.pb = rdpb.ResourceDescriptor()
self.pb.name = name
self.pb.uri = uri
if digest:
self.pb.digest.update(digest)
self.pb.content = content
self.pb.download_location = download_location
self.pb.media_type = media_type
if annotations:
self.pb.annotations.update(annotations)

if rd:
self.pb.CopyFrom(rd)
@staticmethod
def copy_from_pb(proto: type[rdpb.ResourceDescriptor]) -> 'ResourceDescriptor':
rd = ResourceDescriptor()
rd.pb.CopyFrom(proto)
return rd

def validate(self):
def validate(self) -> None:
# at least one of name, URI or digest are required
if (not self.pb.name and not self.pb.uri and not self.pb.digest) or len(self.pb.digest) == 0:
raise ValueError("At least one of name, URI, or digest are required")
if self.pb.name == '' and self.pb.uri == '' and len(self.pb.digest) == 0:
raise ValueError("At least one of name, URI, or digest need to be set")
35 changes: 22 additions & 13 deletions python/in_toto_attestation/v1/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,40 @@
STATEMENT_TYPE_URI = 'https://in-toto.io/Statement/v1'

class Statement:
def __init__(self, s=None):
def __init__(self, subjects: list, predicate_type: str, predicate: dict) -> None:
self.pb = spb.Statement()

if s:
self.pb.CopyFrom(s)

def validate(self):
if not self.pb.type or self.pb.type != STATEMENT_TYPE_URI:
self.pb.type = STATEMENT_TYPE_URI
self.pb.subject.extend(subjects)
self.pb.predicate_type = predicate_type
self.pb.predicate.update(predicate)

@staticmethod
def copy_from_pb(proto: type[spb.Statement]) -> 'Statement':
stmt = Statement([], '', {})
stmt.pb.CopyFrom(proto)
return stmt

def validate(self) -> None:
if self.pb.type != STATEMENT_TYPE_URI:
raise ValueError('Wrong statement type')

if not self.pb.subject or len(self.pb.subject) == 0:
if len(self.pb.subject) == 0:
raise ValueError('At least one subject required')

# check all resource descriptors in the subject
subject = self.pb.subject
for rdpb in subject:
rd = ResourceDescriptor(rdpb)
for i, rdpb in enumerate(subject):
rd = ResourceDescriptor.copy_from_pb(rdpb)
rd.validate()

# v1 statements require the digest to be set in the subject
if len(rd.pb.digest) == 0:
raise ValueError('At least one digest required')
# return index in the subjects list in case of failure:
# can't assume any other fields in subject are set
raise ValueError('At least one digest required (subject {0})'.format(i))

if not self.pb.predicateType:
if self.pb.predicate_type == '':
raise ValueError('Predicate type required')

if not self.pb.predicate:
if len(self.pb.predicate) == 0:
raise ValueError('Predicate object required')
58 changes: 29 additions & 29 deletions tests/python/test_resource_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,51 +9,51 @@
from in_toto_attestation.v1.resource_descriptor import ResourceDescriptor

def create_test_desc():
desc = rdpb.ResourceDescriptor()
desc.name = 'theName'
desc.uri = 'https://example.com'
desc.digest['alg'] = 'abc123'
desc.content = b'bytescontent'
desc.downloadLocation = 'https://example.com/test.zip'
desc.mediaType = 'theMediaType'
desc.annotations['a1'].update({'keyStr': 'val1', 'keyNum': 13})
desc.annotations['a2'].update({'keyObj': {'subKey': 'subVal'}})
return ResourceDescriptor(desc)
desc = ResourceDescriptor('theName', 'https://example.com', {'alg':'abc123'}, b'bytescontent', 'https://example.com/test.zip', 'theMediaType', {'keyStr': 'val1', 'keyNum': 13, 'keyObj': {'subKey': 'subVal'}})

class TestResourceDescriptor(unittest.TestCase):
def setUp(self):
self.want_full_rd = '{"name":"theName","uri":"https://example.com","digest":{"alg":"abc123"},"content":"Ynl0ZXNjb250ZW50","downloadLocation":"https://example.com/test.zip","mediaType":"theMediaType","annotations":{"a1":{"keyNum": 13,"keyStr":"val1"},"a2":{"keyObj":{"subKey":"subVal"}}}}'

self.bad_rd = '{"downloadLocation":"https://example.com/test.zip","mediaType":"theMediaType"}'

self.bad_digest = '{"name":"theName","digest":{},"downloadLocation":"https://example.com/test.zip","mediaType":"theMediaType"}'

self.test_rd = create_test_desc()
return desc

class TestResourceDescriptor(unittest.TestCase):
def test_create_resource_descriptor(self):
self.test_rd.validate()
test_rd = create_test_desc()
test_rd.validate()

def test_json_parse_resource_descriptor(self):
got_pb = pb_json.Parse(self.want_full_rd, rdpb.ResourceDescriptor())
full_rd = '{"name":"theName","uri":"https://example.com","digest":{"alg":"abc123"},"content":"Ynl0ZXNjb250ZW50","downloadLocation":"https://example.com/test.zip","mediaType":"theMediaType","annotations":{"keyNum": 13,"keyStr":"val1","keyObj":{"subKey":"subVal"}}}'
got_pb = pb_json.Parse(full_rd, rdpb.ResourceDescriptor())
got = got_pb.SerializeToString(deterministic=True)

want = self.test_rd.pb.SerializeToString(deterministic=True)
test_rd = create_test_desc()
want = test_rd.pb.SerializeToString(deterministic=True)

self.assertEqual(got, want, 'Protos do not match')

def test_bad_resource_descriptor(self):
got_pb = pb_json.Parse(self.bad_rd, rdpb.ResourceDescriptor())
got = ResourceDescriptor(got_pb)
bad_rd = '{"downloadLocation":"https://example.com/test.zip","mediaType":"theMediaType"}'

with self.assertRaises(ValueError, msg='Error: created malformed ResourceDescriptor'):
got_pb = pb_json.Parse(bad_rd, rdpb.ResourceDescriptor())
got = ResourceDescriptor.copy_from_pb(got_pb)

with self.assertRaises(ValueError, msg='Error: created malformed ResourceDescriptor (no required fields)'):
got.validate()

def test_bad_digest(self):
got_pb = pb_json.Parse(self.bad_digest, rdpb.ResourceDescriptor())
got = ResourceDescriptor(got_pb)
def test_empty_name_only(self):
bad_rd = '{"name":"","downloadLocation":"https://example.com/test.zip","mediaType":"theMediaType"}'

got_pb = pb_json.Parse(bad_rd, rdpb.ResourceDescriptor())
got = ResourceDescriptor.copy_from_pb(got_pb)

with self.assertRaises(ValueError, msg='Error: created ResourceDescriptor with malformed digest field'):
with self.assertRaises(ValueError, msg='Error: created malformed ResourceDescriptor (only empty required fields)'):
got.validate()

def test_empty_digest(self):
empty_digest = '{"name":"theName","digest":{},"downloadLocation":"https://example.com/test.zip","mediaType":"theMediaType"}'

got_pb = pb_json.Parse(empty_digest, rdpb.ResourceDescriptor())
got = ResourceDescriptor.copy_from_pb(got_pb)

# this should not raise an error
got.validate()

if __name__ == '__main__':
unittest.main()
34 changes: 14 additions & 20 deletions tests/python/test_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,42 +8,36 @@
import in_toto_attestation.v1.statement_pb2 as stpb
import in_toto_attestation.v1.resource_descriptor_pb2 as rdpb

import in_toto_attestation.v1.statement as its
from in_toto_attestation.v1.statement import Statement

def create_test_statement():
sub = rdpb.ResourceDescriptor()
sub.name = 'theSub'
sub.digest['alg1'] = 'abc123'

st = stpb.Statement()
st.type = its.STATEMENT_TYPE_URI
st.subject.append(sub)
st.predicateType = 'thePredicate'
st.predicate.update({'keyObj': {'subKey': 'subVal'}})
return its.Statement(st)
stmt = Statement([sub], 'thePredicate', {'keyObj': {'subKey': 'subVal'}})
return stmt

class TestStatement(unittest.TestCase):
def setUp(self):
self.want_st = '{"_type":"https://in-toto.io/Statement/v1","subject":[{"name":"theSub","digest":{"alg1":"abc123"}}],"predicateType":"thePredicate","predicate":{"keyObj":{"subKey":"subVal"}}}'

self.test_st = create_test_statement()

def test_create_statement(self):
self.test_st.validate()
test_st = create_test_statement()
test_st.validate()

def test_json_parse_statement(self):
got_pb = pb_json.Parse(self.want_st, stpb.Statement())
full_st = '{"_type":"https://in-toto.io/Statement/v1","subject":[{"name":"theSub","digest":{"alg1":"abc123"}}],"predicateType":"thePredicate","predicate":{"keyObj":{"subKey":"subVal"}}}'
got_pb = pb_json.Parse(full_st, stpb.Statement())
got = got_pb.SerializeToString(deterministic=True)

want = self.test_st.pb.SerializeToString(deterministic=True)
test_st = create_test_statement()
want = test_st.pb.SerializeToString(deterministic=True)

self.assertEqual(got, want, 'Protos do not match')

def test_bad_statement_type(self):
bad_st = '{"_type":"https://in-toto.io/Statement/v0","subject":[{"name":"theSub","digest":{"alg1":"abc123"}}],"predicateType":"thePredicate","predicate":{"keyObj":{"subKey":"subVal"}}}'

got_pb = pb_json.Parse(bad_st, stpb.Statement())
got = its.Statement(got_pb)
got = Statement.copy_from_pb(got_pb)

with self.assertRaises(ValueError, msg='Error: created malformed Statement (bad type)'):
got.validate()
Expand All @@ -52,7 +46,7 @@ def test_bad_statement_empty_subject(self):
bad_st = '{"_type":"https://in-toto.io/Statement/v1","subject":[],"predicateType":"thePredicate","predicate":{"keyObj":{"subKey":"subVal"}}}'

got_pb = pb_json.Parse(bad_st, stpb.Statement())
got = its.Statement(got_pb)
got = Statement.copy_from_pb(got_pb)

with self.assertRaises(ValueError, msg='Error: created malformed Statement (empty subject)'):
got.validate()
Expand All @@ -61,7 +55,7 @@ def test_bad_statement_bad_subject(self):
bad_st = '{"_type":"https://in-toto.io/Statement/v1","subject":[{"digest":{}}],"predicateType":"thePredicate","predicate":{"keyObj":{"subKey":"subVal"}}}'

got_pb = pb_json.Parse(bad_st, stpb.Statement())
got = its.Statement(got_pb)
got = Statement.copy_from_pb(got_pb)

with self.assertRaises(ValueError, msg='Error: created malformed Statement (bad subject)'):
got.validate()
Expand All @@ -70,7 +64,7 @@ def test_bad_predicate_type(self):
bad_st = '{"_type":"https://in-toto.io/Statement/v1","subject":[{"name":"theSub","digest":{"alg1":"abc123"}}],"predicateType":"","predicate":{"keyObj":{"subKey":"subVal"}}}'

got_pb = pb_json.Parse(bad_st, stpb.Statement())
got = its.Statement(got_pb)
got = Statement.copy_from_pb(got_pb)

with self.assertRaises(ValueError, msg='Error: created malformed Statement (bad predicate type)'):
got.validate()
Expand All @@ -79,7 +73,7 @@ def test_bad_predicate(self):
bad_st = '{"_type":"https://in-toto.io/Statement/v1","subject":[{"name":"theSub","digest":{"alg1":"abc123"}}],"predicateType":"thePredicate"}'

got_pb = pb_json.Parse(bad_st, stpb.Statement())
got = its.Statement(got_pb)
got = Statement.copy_from_pb(got_pb)

with self.assertRaises(ValueError, msg='Error: created malformed Statement (bad predicate)'):
got.validate()
Expand Down

0 comments on commit e5f2403

Please sign in to comment.