Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const ParamLoraCollapse = () => {
}

return (
<IAICollapse label={'LoRA'} activeLabel={activeLabel}>
<IAICollapse label="LoRA" activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 2 }}>
<ParamLoRASelect />
<ParamLoraList />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { Divider } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
Expand All @@ -8,20 +9,21 @@ import ParamLora from './ParamLora';
const selector = createSelector(
stateSelector,
({ lora }) => {
const { loras } = lora;

return { loras };
return { lorasArray: map(lora.loras) };
},
defaultSelectorOptions
);

const ParamLoraList = () => {
const { loras } = useAppSelector(selector);
const { lorasArray } = useAppSelector(selector);

return (
<>
{map(loras, (lora) => (
<ParamLora key={lora.model_name} lora={lora} />
{lorasArray.map((lora, i) => (
<>
{i > 0 && <Divider key={`${lora.model_name}-divider`} pt={1} />}
<ParamLora key={lora.model_name} lora={lora} />
</>
))}
</>
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import {
CLIP_SKIP,
LORA_LOADER,
MAIN_MODEL_LOADER,
ONNX_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
Expand All @@ -36,15 +35,11 @@ export const addLoRAsToGraph = (
| undefined;

if (loraCount > 0) {
// Remove MAIN_MODEL_LOADER unet connection to feed it to LoRAs
// Remove modelLoaderNodeId unet connection to feed it to LoRAs
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === MAIN_MODEL_LOADER &&
['unet'].includes(e.source.field)
) &&
!(
e.source.node_id === ONNX_MODEL_LOADER &&
e.source.node_id === modelLoaderNodeId &&
['unet'].includes(e.source.field)
)
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { forEach, size } from 'lodash-es';
import {
MetadataAccumulatorInvocation,
SDXLLoraLoaderInvocation,
} from 'services/api/types';
import {
LORA_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
SDXL_MODEL_LOADER,
} from './constants';

export const addSDXLLoRAsToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId: string = SDXL_MODEL_LOADER
): void => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
* or to the inference/conditioning nodes.
*
* So we need to inject a LoRA chain into the graph.
*/

const { loras } = state.lora;
const loraCount = size(loras);
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;

if (loraCount > 0) {
// Remove modelLoaderNodeId unet/clip/clip2 connections to feed it to LoRAs
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === modelLoaderNodeId &&
['unet'].includes(e.source.field)
) &&
!(
e.source.node_id === modelLoaderNodeId &&
['clip'].includes(e.source.field)
) &&
!(
e.source.node_id === modelLoaderNodeId &&
['clip2'].includes(e.source.field)
)
);
}

// we need to remember the last lora so we can chain from it
let lastLoraNodeId = '';
let currentLoraIndex = 0;

forEach(loras, (lora) => {
const { model_name, base_model, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;

const loraLoaderNode: SDXLLoraLoaderInvocation = {
type: 'sdxl_lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
lora: { model_name, base_model },
weight,
};

// add the lora to the metadata accumulator
if (metadataAccumulator) {
metadataAccumulator.loras.push({
lora: { model_name, base_model },
weight,
});
}

// add to graph
graph.nodes[currentLoraNodeId] = loraLoaderNode;
if (currentLoraIndex === 0) {
// first lora = start the lora chain, attach directly to model loader
graph.edges.push({
source: {
node_id: modelLoaderNodeId,
field: 'unet',
},
destination: {
node_id: currentLoraNodeId,
field: 'unet',
},
});

graph.edges.push({
source: {
node_id: modelLoaderNodeId,
field: 'clip',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip',
},
});

graph.edges.push({
source: {
node_id: modelLoaderNodeId,
field: 'clip2',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip2',
},
});
} else {
// we are in the middle of the lora chain, instead connect to the previous lora
graph.edges.push({
source: {
node_id: lastLoraNodeId,
field: 'unet',
},
destination: {
node_id: currentLoraNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: lastLoraNodeId,
field: 'clip',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip',
},
});

graph.edges.push({
source: {
node_id: lastLoraNodeId,
field: 'clip2',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip2',
},
});
}

if (currentLoraIndex === loraCount - 1) {
// final lora, end the lora chain - we need to connect up to inference and conditioning nodes
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'unet',
},
destination: {
node_id: baseNodeId,
field: 'unet',
},
});

graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
});

graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
});

graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip2',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip2',
},
});

graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip2',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip2',
},
});
}

// increment the lora for the next one in the chain
lastLoraNodeId = currentLoraNodeId;
currentLoraIndex += 1;
});
};
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
SDXL_LATENTS_TO_LATENTS,
SDXL_MODEL_LOADER,
} from './constants';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';

/**
* Builds the Image to Image tab graph.
Expand Down Expand Up @@ -364,6 +365,8 @@ export const buildLinearSDXLImageToImageGraph = (
},
});

addSDXLLoRAsToGraph(state, graph, SDXL_LATENTS_TO_LATENTS, SDXL_MODEL_LOADER);

// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_LATENTS_TO_LATENTS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
Expand Down Expand Up @@ -246,6 +247,8 @@ export const buildLinearSDXLTextToImageGraph = (
},
});

addSDXLLoRAsToGraph(state, graph, SDXL_TEXT_TO_LATENTS, SDXL_MODEL_LOADER);

// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_TEXT_TO_LATENTS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces
import ParamSDXLPromptArea from './ParamSDXLPromptArea';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
import SDXLImageToImageTabCoreParameters from './SDXLImageToImageTabCoreParameters';
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';

const SDXLImageToImageTabParameters = () => {
return (
Expand All @@ -12,6 +13,7 @@ const SDXLImageToImageTabParameters = () => {
<ProcessButtons />
<SDXLImageToImageTabCoreParameters />
<ParamSDXLRefinerCollapse />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
</>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
import ParamSDXLPromptArea from './ParamSDXLPromptArea';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';

const SDXLTextToImageTabParameters = () => {
return (
Expand All @@ -12,6 +13,7 @@ const SDXLTextToImageTabParameters = () => {
<ProcessButtons />
<TextToImageTabCoreParameters />
<ParamSDXLRefinerCollapse />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
</>
Expand Down
Loading