diff --git a/thrift/compiler/generate/t_ast_generator.cc b/thrift/compiler/generate/t_ast_generator.cc index a6c2b20954c..154d0d563d1 100644 --- a/thrift/compiler/generate/t_ast_generator.cc +++ b/thrift/compiler/generate/t_ast_generator.cc @@ -216,6 +216,10 @@ void t_ast_generator::generate_program() { auto program_id = positionToId(pos); program_index[&program] = program_id; hydrate_const(programs.emplace_back(), *schema_source.gen_schema(program)); + if (program.has_doc()) { + programs.back().attrs()->docs()->sourceRange() = + src_range(program.doc_range(), &program); + } auto is_root_program = std::exchange(is_root_program_, false); for (auto* include : program.get_included_programs()) { diff --git a/thrift/compiler/test/ast_generator_test.py b/thrift/compiler/test/ast_generator_test.py index 41bec504fa3..85021c7f916 100644 --- a/thrift/compiler/test/ast_generator_test.py +++ b/thrift/compiler/test/ast_generator_test.py @@ -167,7 +167,6 @@ def test_service(self): self.assertEqual(func.returnType.name.type, TypeName.Type.i32Type) func = serviceDef.functions[2] - print(func) self.assertEqual(func.attrs.name, "baz") self.assertEqual(len(func.returnTypes), 0) # streams excluded self.assertEqual(func.returnType.name.type, TypeName.Type.EMPTY) @@ -202,6 +201,8 @@ def test_docs(self): "foo.thrift", textwrap.dedent( """ + /** File docblock */ + namespace cpp2 foo /** This is a struct and its name is Foo. */ struct Foo { 1: i64 Bar; ///< This is a field and its name is Bar. @@ -211,19 +212,24 @@ def test_docs(self): ) ast = self.run_thrift("foo.thrift") - print(ast.definitions) + + docs = ast.programs[0].attrs.docs + self.assertEqual(docs.contents.rstrip(), "File docblock") + self.assertEqual(docs.sourceRange.programId, 1) + self.assertEqual(docs.sourceRange.beginLine, 2) + self.assertEqual(ast.definitions[0].structDef.attrs.name, "Foo") srcRange = ast.definitions[0].structDef.attrs.sourceRange self.assertEqual(srcRange.programId, 1) - self.assertEqual(srcRange.beginLine, 3) + self.assertEqual(srcRange.beginLine, 5) docs = ast.definitions[0].structDef.attrs.docs self.assertEqual( docs.contents.rstrip(), "This is a struct and its name is Foo." ) self.assertEqual(docs.sourceRange.programId, 1) - self.assertEqual(docs.sourceRange.beginLine, 2) + self.assertEqual(docs.sourceRange.beginLine, 4) self.assertEqual(docs.sourceRange.beginColumn, 1) - self.assertEqual(docs.sourceRange.endLine, 2) + self.assertEqual(docs.sourceRange.endLine, 4) self.assertEqual(docs.sourceRange.endColumn, 45) docs = ast.definitions[0].structDef.fields[0].attrs.docs @@ -231,9 +237,9 @@ def test_docs(self): if os.name == "nt": return # line separators differ on windows and mess up source ranges self.assertEqual(docs.sourceRange.programId, 1) - self.assertEqual(docs.sourceRange.beginLine, 4) + self.assertEqual(docs.sourceRange.beginLine, 6) self.assertEqual(docs.sourceRange.beginColumn, 18) - self.assertEqual(docs.sourceRange.endLine, 5) + self.assertEqual(docs.sourceRange.endLine, 7) self.assertEqual(docs.sourceRange.endColumn, 1) def test_program(self):