diff --git a/src/python/pants/backend/jvm/tasks/jvm_compile/jvm_classpath_publisher.py b/src/python/pants/backend/jvm/tasks/jvm_compile/jvm_classpath_publisher.py index b0f073f928e..24e19eede33 100644 --- a/src/python/pants/backend/jvm/tasks/jvm_compile/jvm_classpath_publisher.py +++ b/src/python/pants/backend/jvm/tasks/jvm_compile/jvm_classpath_publisher.py @@ -3,6 +3,8 @@ import os +from twitter.common.collections import OrderedSet + from pants.backend.jvm.tasks.classpath_products import ClasspathProducts from pants.backend.jvm.tasks.classpath_util import ClasspathUtil from pants.java.util import safe_classpath @@ -17,6 +19,10 @@ def register_options(cls, register): super().register_options(register) register('--manifest-jar-only', type=bool, default=False, help='Only export classpath in a manifest jar.') + register('--transitive-only', type=bool, default=False, + help='Only export the classpath of the transitive dependencies of the target roots. ' + 'This avoids jarring up the target roots themselves, which allows an IDE to ' + 'insert their own modules more easily to cover the source files of target roots.') @classmethod def prepare(cls, options, round_manager): @@ -29,7 +35,10 @@ def _output_folder(self): def execute(self): basedir = os.path.join(self.get_options().pants_distdir, self._output_folder) runtime_classpath = self.context.products.get_data('runtime_classpath') - targets = self.context.targets() + + targets = OrderedSet(self.get_targets()) - set(self.context.target_roots) \ + if self.get_options().transitive_only else self.get_targets() + if self.get_options().manifest_jar_only: classpath = ClasspathUtil.classpath(targets, runtime_classpath) # Safely create e.g. dist/export-classpath/manifest.jar diff --git a/tests/python/pants_test/backend/jvm/tasks/jvm_compile/test_jvm_classpath_published.py b/tests/python/pants_test/backend/jvm/tasks/jvm_compile/test_jvm_classpath_published.py index 31813385c3a..a55b50c67d6 100644 --- a/tests/python/pants_test/backend/jvm/tasks/jvm_compile/test_jvm_classpath_published.py +++ b/tests/python/pants_test/backend/jvm/tasks/jvm_compile/test_jvm_classpath_published.py @@ -79,3 +79,53 @@ def test_incremental_caching(self): os.pathsep.join([os.path.join(jar_dir, 'z2.jar'), os.path.join(jar_dir, 'z3.jar')]) + '\n' ] ) + + def _assert_jars_created(self, *, transitive_only: bool) -> None: + with temporary_dir(root_dir=self.pants_workdir) as jar_dir, \ + temporary_dir(root_dir=self.pants_workdir) as dist_dir: + self.set_options(pants_distdir=dist_dir, + transitive_only=transitive_only) + + init_target = self.make_target( + 'java/classpath:java_lib', + target_type=JavaLibrary, + sources=['com/foo/Bar.java'], + ) + target_with_dep = self.make_target( + 'java/classpath:java_lib_with_dep', + target_type=JavaLibrary, + sources=['com/foo/Bar.java'], + dependencies=[init_target], + ) + context = self.context(target_roots=[target_with_dep]) + runtime_classpath = context.products.get_data('runtime_classpath', + init_func=ClasspathProducts.init_func(self.pants_workdir)) + task = self.create_task(context) + + target_classpath_output = os.path.join(dist_dir, self.options_scope) + + # Create a classpath entry. + touch(os.path.join(jar_dir, 'dep-target.jar')) + touch(os.path.join(jar_dir, 'root-target.jar')) + runtime_classpath.add_for_target(init_target, [(self.DEFAULT_CONF, os.path.join(jar_dir, 'dep-target.jar'))]) + runtime_classpath.add_for_target(target_with_dep, [(self.DEFAULT_CONF, os.path.join(jar_dir, 'root-target.jar'))]) + task.execute() + + all_output = os.listdir(target_classpath_output) + + # Check only one symlink and classpath.txt were created. + expected_artifacts = 2 if transitive_only else 4 + self.assertEqual(len(all_output), expected_artifacts) + + + self.assertIn('java.classpath.java_lib-0.jar', all_output) + if transitive_only: + self.assertNotIn('java.classpath.java_lib_with_dep-0.jar', all_output) + else: + self.assertIn('java.classpath.java_lib_with_dep-0.jar', all_output) + + def test_transitive_only(self): + self._assert_jars_created(transitive_only=True) + + def test_no_transitive_only(self): + self._assert_jars_created(transitive_only=False)