diff --git a/scripts/controlnet.py b/scripts/controlnet.py index 9bea442be..9da02a569 100644 --- a/scripts/controlnet.py +++ b/scripts/controlnet.py @@ -16,6 +16,7 @@ # Register all preprocessors. import scripts.preprocessor as preprocessor_init # noqa from annotator.util import HWC3 +from internal_controlnet.external_code import ControlNetUnit from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils from scripts.controlnet_lora import bind_control_lora, unbind_control_lora from scripts.controlnet_lllite import clear_all_lllite @@ -228,7 +229,7 @@ def get_pytorch_control(x: np.ndarray) -> torch.Tensor: def get_control( p: StableDiffusionProcessing, - unit: external_code.ControlNetUnit, + unit: ControlNetUnit, idx: int, control_model_type: ControlModelType, preprocessor: Preprocessor, @@ -338,7 +339,7 @@ def __init__(self) -> None: self.latest_network = None self.input_image = None self.latest_model_hash = "" - self.enabled_units: List[external_code.ControlNetUnit] = [] + self.enabled_units: List[ControlNetUnit] = [] self.detected_map = [] self.post_processors = [] self.noise_modifier = None @@ -356,7 +357,7 @@ def show(self, is_img2img): @staticmethod def get_default_ui_unit(is_ui=True): - cls = UiControlNetUnit if is_ui else external_code.ControlNetUnit + cls = UiControlNetUnit if is_ui else ControlNetUnit return cls( enabled=False, module="none", @@ -527,7 +528,7 @@ def get_element(obj, strict=False): return attribute_value if attribute_value is not None else default @staticmethod - def parse_remote_call(p, unit: external_code.ControlNetUnit, idx): + def parse_remote_call(p, unit: ControlNetUnit, idx): selector = Script.get_remote_call unit.enabled = selector(p, "control_net_enabled", unit.enabled, idx, strict=True) @@ -688,7 +689,7 @@ def get_enabled_units(p): @staticmethod def choose_input_image( p: processing.StableDiffusionProcessing, - unit: external_code.ControlNetUnit, + unit: ControlNetUnit, idx: int ) -> Tuple[np.ndarray, ResizeMode]: """ Choose input image from following sources with descending priority: @@ -701,7 +702,7 @@ def choose_input_image( - The input image in ndarray form. - The resize mode. """ - def parse_unit_image(unit: external_code.ControlNetUnit) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]: + def parse_unit_image(unit: ControlNetUnit) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]: unit_has_multiple_images = ( isinstance(unit.image, list) and len(unit.image) > 0 and @@ -810,7 +811,7 @@ def decode_image(img) -> np.ndarray: @staticmethod def try_crop_image_with_a1111_mask( p: StableDiffusionProcessing, - unit: external_code.ControlNetUnit, + unit: ControlNetUnit, input_image: np.ndarray, resize_mode: ResizeMode, ) -> np.ndarray: @@ -863,7 +864,7 @@ def try_crop_image_with_a1111_mask( return input_image @staticmethod - def check_sd_version_compatible(unit: external_code.ControlNetUnit) -> None: + def check_sd_version_compatible(unit: ControlNetUnit) -> None: """ Checks whether the given ControlNet unit has model compatible with the currently active sd model. An exception is thrown if ControlNet unit is detected to be diff --git a/scripts/controlnet_ui/controlnet_ui_group.py b/scripts/controlnet_ui/controlnet_ui_group.py index a30f3590c..e9dc95225 100644 --- a/scripts/controlnet_ui/controlnet_ui_group.py +++ b/scripts/controlnet_ui/controlnet_ui_group.py @@ -13,6 +13,7 @@ external_code, ) from annotator.util import HWC3 +from internal_controlnet.external_code import ControlNetUnit from scripts.logging import logger from scripts.controlnet_ui.openpose_editor import OpenposeEditor from scripts.controlnet_ui.preset import ControlNetPresetUI @@ -127,7 +128,7 @@ def set_component(self, component: gr.components.Component): ) -class UiControlNetUnit(external_code.ControlNetUnit): +class UiControlNetUnit(ControlNetUnit): """The data class that stores all states of a ControlNetUnit.""" def __init__( @@ -167,7 +168,7 @@ def __init__( self.output_dir = output_dir self.loopback = loopback - def unfold_merged(self) -> List[external_code.ControlNetUnit]: + def unfold_merged(self) -> List[ControlNetUnit]: """Unfolds a merged unit to multiple units. Keeps the unit merged for preprocessors that can accept multiple input images. """ @@ -220,7 +221,7 @@ class ControlNetUiGroup(object): def __init__( self, is_img2img: bool, - default_unit: external_code.ControlNetUnit, + default_unit: ControlNetUnit, photopea: Optional[Photopea], ): # Whether callbacks have been registered. @@ -1260,7 +1261,7 @@ def register_core_callbacks(self): self.type_filter, *[ getattr(self, key) - for key in vars(external_code.ControlNetUnit()).keys() + for key in vars(ControlNetUnit()).keys() ], ) self.advanced_weight_control.register_callbacks( diff --git a/scripts/controlnet_ui/preset.py b/scripts/controlnet_ui/preset.py index 3010d2617..9882acf53 100644 --- a/scripts/controlnet_ui/preset.py +++ b/scripts/controlnet_ui/preset.py @@ -5,9 +5,9 @@ from modules import scripts from modules.ui_components import ToolButton +from internal_controlnet.external_code import ControlNetUnit from scripts.infotext import parse_unit, serialize_unit from scripts.logging import logger -from scripts import external_code from scripts.supported_preprocessor import Preprocessor save_symbol = "\U0001f4be" # 💾 @@ -113,7 +113,7 @@ def apply_preset(name: str, control_type: str, *ui_states): gr.update(visible=False), *( (gr.skip(),) - * (len(vars(external_code.ControlNetUnit()).keys()) + 1) + * (len(vars(ControlNetUnit()).keys()) + 1) ), ) @@ -121,7 +121,7 @@ def apply_preset(name: str, control_type: str, *ui_states): infotext = ControlNetPresetUI.presets[name] preset_unit = parse_unit(infotext) - current_unit = external_code.ControlNetUnit(*ui_states) + current_unit = ControlNetUnit(*ui_states) preset_unit.image = None current_unit.image = None @@ -136,7 +136,7 @@ def apply_preset(name: str, control_type: str, *ui_states): gr.update(visible=False), *( (gr.skip(),) - * (len(vars(external_code.ControlNetUnit()).keys()) + 1) + * (len(vars(ControlNetUnit()).keys()) + 1) ), ) @@ -177,7 +177,7 @@ def save_preset(name: str, *ui_states): return gr.update(visible=True), gr.update(), gr.update() ControlNetPresetUI.save_preset( - name, external_code.ControlNetUnit(*ui_states) + name, ControlNetUnit(*ui_states) ) return ( gr.update(), # name dialog @@ -222,7 +222,7 @@ def save_new_preset(new_name: str, *ui_states): return gr.update(visible=False), gr.update() ControlNetPresetUI.save_preset( - new_name, external_code.ControlNetUnit(*ui_states) + new_name, ControlNetUnit(*ui_states) ) return gr.update(visible=False), gr.update( choices=ControlNetPresetUI.dropdown_choices(), value=new_name @@ -248,7 +248,7 @@ def update_reset_button(preset_name: str, *ui_states): infotext = ControlNetPresetUI.presets[preset_name] preset_unit = parse_unit(infotext) - current_unit = external_code.ControlNetUnit(*ui_states) + current_unit = ControlNetUnit(*ui_states) preset_unit.image = None current_unit.image = None @@ -279,7 +279,7 @@ def dropdown_choices() -> List[str]: return list(ControlNetPresetUI.presets.keys()) + [NEW_PRESET] @staticmethod - def save_preset(name: str, unit: external_code.ControlNetUnit): + def save_preset(name: str, unit: ControlNetUnit): infotext = serialize_unit(unit) with open( os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt"), "w" diff --git a/scripts/infotext.py b/scripts/infotext.py index 9dbdad1d8..f0f3f212b 100644 --- a/scripts/infotext.py +++ b/scripts/infotext.py @@ -4,7 +4,7 @@ from modules.processing import StableDiffusionProcessing -from scripts import external_code +from internal_controlnet.external_code import ControlNetUnit from scripts.logging import logger @@ -28,12 +28,12 @@ def parse_value(value: str) -> Union[str, float, int, bool]: return value # Plain string. -def serialize_unit(unit: external_code.ControlNetUnit) -> str: - excluded_fields = external_code.ControlNetUnit.infotext_excluded_fields() +def serialize_unit(unit: ControlNetUnit) -> str: + excluded_fields = ControlNetUnit.infotext_excluded_fields() log_value = { field_to_displaytext(field): getattr(unit, field) - for field in vars(external_code.ControlNetUnit()).keys() + for field in vars(ControlNetUnit()).keys() if field not in excluded_fields and getattr(unit, field) != -1 # Note: exclude hidden slider values. } @@ -44,8 +44,8 @@ def serialize_unit(unit: external_code.ControlNetUnit) -> str: return ", ".join(f"{field}: {value}" for field, value in log_value.items()) -def parse_unit(text: str) -> external_code.ControlNetUnit: - return external_code.ControlNetUnit( +def parse_unit(text: str) -> ControlNetUnit: + return ControlNetUnit( enabled=True, **{ displaytext_to_field(key): parse_value(value) @@ -74,7 +74,7 @@ def register_unit(self, unit_index: int, uigroup) -> None: iocomponents. """ unit_prefix = Infotext.unit_prefix(unit_index) - for field in vars(external_code.ControlNetUnit()).keys(): + for field in vars(ControlNetUnit()).keys(): # Exclude image for infotext. if field == "image": continue @@ -88,7 +88,7 @@ def register_unit(self, unit_index: int, uigroup) -> None: @staticmethod def write_infotext( - units: List[external_code.ControlNetUnit], p: StableDiffusionProcessing + units: List[ControlNetUnit], p: StableDiffusionProcessing ): """Write infotext to `p`.""" p.extra_generation_params.update( diff --git a/tests/cn_script/batch_hijack_test.py b/tests/cn_script/batch_hijack_test.py index 0f68fe5bc..32afd51f4 100644 --- a/tests/cn_script/batch_hijack_test.py +++ b/tests/cn_script/batch_hijack_test.py @@ -6,6 +6,7 @@ from modules import processing, scripts, shared +from internal_controlnet.external_code import ControlNetUnit from scripts import controlnet, external_code, batch_hijack @@ -73,15 +74,15 @@ def test_get_cn_batches__empty(self): self.assertEqual(is_batch, False) def test_get_cn_batches__1_simple(self): - self.p.script_args.append(external_code.ControlNetUnit(image=get_dummy_image())) + self.p.script_args.append(ControlNetUnit(image=get_dummy_image())) self.assert_get_cn_batches_works([ [self.p.script_args[0].image], ]) def test_get_cn_batches__2_simples(self): self.p.script_args.extend([ - external_code.ControlNetUnit(image=get_dummy_image(0)), - external_code.ControlNetUnit(image=get_dummy_image(1)), + ControlNetUnit(image=get_dummy_image(0)), + ControlNetUnit(image=get_dummy_image(1)), ]) self.assert_get_cn_batches_works([ [get_dummy_image(0)], @@ -135,7 +136,7 @@ def test_get_cn_batches__2_batches(self): def test_get_cn_batches__2_mixed(self): self.p.script_args.extend([ - external_code.ControlNetUnit(image=get_dummy_image(0)), + ControlNetUnit(image=get_dummy_image(0)), controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ @@ -157,7 +158,7 @@ def test_get_cn_batches__2_mixed(self): def test_get_cn_batches__3_mixed(self): self.p.script_args.extend([ - external_code.ControlNetUnit(image=get_dummy_image(0)), + ControlNetUnit(image=get_dummy_image(0)), controlnet.UiControlNetUnit( input_mode=batch_hijack.InputMode.BATCH, batch_images=[ @@ -242,8 +243,8 @@ def test_process_images_no_units_forwards(self): def test_process_images__only_simple_units__forwards(self): self.p.script_args = [ - external_code.ControlNetUnit(image=get_dummy_image()), - external_code.ControlNetUnit(image=get_dummy_image()), + ControlNetUnit(image=get_dummy_image()), + ControlNetUnit(image=get_dummy_image()), ] self.assert_process_images_hijack_called(batch_count=0) diff --git a/tests/cn_script/cn_script_test.py b/tests/cn_script/cn_script_test.py index 23549bf15..e3d769335 100644 --- a/tests/cn_script/cn_script_test.py +++ b/tests/cn_script/cn_script_test.py @@ -10,6 +10,7 @@ from scripts import external_code from scripts.enums import ResizeMode from scripts.controlnet import prepare_mask, Script, set_numpy_seed +from internal_controlnet.external_code import ControlNetUnit from modules import processing @@ -127,7 +128,7 @@ def test_choose_input_image(self): with self.assertRaises(ValueError): Script.choose_input_image( p=processing.StableDiffusionProcessing(), - unit=external_code.ControlNetUnit(), + unit=ControlNetUnit(), idx=0, ) @@ -137,7 +138,7 @@ def test_choose_input_image(self): init_images=[TestScript.sample_np_image], resize_mode=ResizeMode.OUTER_FIT, ), - unit=external_code.ControlNetUnit( + unit=ControlNetUnit( image=TestScript.sample_base64_image, module="none", resize_mode=ResizeMode.INNER_FIT, @@ -152,7 +153,7 @@ def test_choose_input_image(self): init_images=[TestScript.sample_np_image], resize_mode=ResizeMode.OUTER_FIT, ), - unit=external_code.ControlNetUnit( + unit=ControlNetUnit( module="none", resize_mode=ResizeMode.INNER_FIT, ), diff --git a/tests/external_code_api/external_code_test.py b/tests/external_code_api/external_code_test.py index 4772e3b53..a58eb4de5 100644 --- a/tests/external_code_api/external_code_test.py +++ b/tests/external_code_api/external_code_test.py @@ -10,6 +10,7 @@ from scripts import external_code from scripts import controlnet from scripts.enums import ResizeMode +from internal_controlnet.external_code import ControlNetUnit from modules import scripts, ui, shared @@ -49,7 +50,7 @@ def test_empty_resizes_min_args(self): def test_empty_resizes_extra_args(self): extra_models = 1 - self.cn_units = [external_code.ControlNetUnit()] * (self.max_models + extra_models) + self.cn_units = [ControlNetUnit()] * (self.max_models + extra_models) self.assert_update_in_place_ok() @@ -57,7 +58,7 @@ class TestControlNetUnitConversion(unittest.TestCase): def setUp(self): self.dummy_image = 'base64...' self.input = {} - self.expected = external_code.ControlNetUnit() + self.expected = ControlNetUnit() def assert_converts_to_expected(self): self.assertEqual(vars(external_code.to_processing_unit(self.input)), vars(self.expected)) @@ -69,14 +70,14 @@ def test_image_works(self): self.input = { 'image': self.dummy_image } - self.expected = external_code.ControlNetUnit(image=self.dummy_image) + self.expected = ControlNetUnit(image=self.dummy_image) self.assert_converts_to_expected() def test_image_alias_works(self): self.input = { 'input_image': self.dummy_image } - self.expected = external_code.ControlNetUnit(image=self.dummy_image) + self.expected = ControlNetUnit(image=self.dummy_image) self.assert_converts_to_expected() def test_masked_image_works(self): @@ -84,14 +85,14 @@ def test_masked_image_works(self): 'image': self.dummy_image, 'mask': self.dummy_image, } - self.expected = external_code.ControlNetUnit(image={'image': self.dummy_image, 'mask': self.dummy_image}) + self.expected = ControlNetUnit(image={'image': self.dummy_image, 'mask': self.dummy_image}) self.assert_converts_to_expected() class TestControlNetUnitImageToDict(unittest.TestCase): def setUp(self): self.dummy_image = utils.readImage("test/test_files/img2img_basic.png") - self.input = external_code.ControlNetUnit() + self.input = ControlNetUnit() self.expected_image = external_code.to_base64_nparray(self.dummy_image) self.expected_mask = external_code.to_base64_nparray(self.dummy_image) @@ -143,7 +144,7 @@ def test_bool(self): self.assertListEqual(external_code.get_all_units_from([True]), []) def test_inheritance(self): - class Foo(external_code.ControlNetUnit): + class Foo(ControlNetUnit): def __init__(self): super().__init__(self) self.bar = 'a' @@ -154,14 +155,14 @@ def __init__(self): def test_dict(self): units = external_code.get_all_units_from([{}]) self.assertGreater(len(units), 0) - self.assertIsInstance(units[0], external_code.ControlNetUnit) + self.assertIsInstance(units[0], ControlNetUnit) def test_unitlike(self): class Foo(object): """ bar """ foo = Foo() - for key in vars(external_code.ControlNetUnit()).keys(): + for key in vars(ControlNetUnit()).keys(): setattr(foo, key, True) setattr(foo, 'bar', False) self.assertListEqual(external_code.get_all_units_from([foo]), [foo]) diff --git a/tests/external_code_api/script_args_test.py b/tests/external_code_api/script_args_test.py index 97c58f6d2..40075aa65 100644 --- a/tests/external_code_api/script_args_test.py +++ b/tests/external_code_api/script_args_test.py @@ -5,6 +5,7 @@ from scripts import external_code from scripts.enums import ControlMode +from internal_controlnet.external_code import ControlNetUnit class TestGetAllUnitsFrom(unittest.TestCase): @@ -18,7 +19,7 @@ def setUp(self): "processor_res": 64, "control_mode": ControlMode.BALANCED.value, } - self.object_unit = external_code.ControlNetUnit(**self.control_unit) + self.object_unit = ControlNetUnit(**self.control_unit) def test_empty_converts(self): script_args = []