Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add sd-webui-dtg extension support #263

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 79 additions & 1 deletion src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,66 @@ export const PromptConfig: Schema<PromptConfig> = Schema.object({
maxWords: Schema.computed(Schema.natural(), options).description('允许的最大单词数量。').default(0),
}).description('输入设置')

// Model identifiers offered for the `model` option of the sd-webui-dtg settings.
// The first five entries are bare `owner/name` ids; the remaining entries append
// `|<file>.gguf` to one of those ids, one per quantisation (Q6_K, Q8_0, f16).
// NOTE(review): presumably these must match the model names the sd-webui-dtg
// extension itself lists — confirm against the extension before editing.
const dtgModels = [
'KBlueLeaf/DanTagGen-delta-rev2',
'KBlueLeaf/DanTagGen-delta',
'KBlueLeaf/DanTagGen-beta',
'KBlueLeaf/DanTagGen-alpha',
'KBlueLeaf/DanTagGen-gamma',
'KBlueLeaf/DanTagGen-delta-rev2|ggml-model-Q6_K.gguf',
'KBlueLeaf/DanTagGen-delta-rev2|ggml-model-Q8_0.gguf',
'KBlueLeaf/DanTagGen-delta-rev2|ggml-model-f16.gguf',
'KBlueLeaf/DanTagGen-delta|ggml-model-Q6_K.gguf',
'KBlueLeaf/DanTagGen-delta|ggml-model-Q8_0.gguf',
'KBlueLeaf/DanTagGen-delta|ggml-model-f16.gguf',
'KBlueLeaf/DanTagGen-beta|ggml-model-Q6_K.gguf',
'KBlueLeaf/DanTagGen-beta|ggml-model-Q8_0.gguf',
'KBlueLeaf/DanTagGen-beta|ggml-model-f16.gguf',
'KBlueLeaf/DanTagGen-alpha|ggml-model-Q6_K.gguf',
'KBlueLeaf/DanTagGen-alpha|ggml-model-Q8_0.gguf',
'KBlueLeaf/DanTagGen-alpha|ggml-model-f16.gguf',
'KBlueLeaf/DanTagGen-gamma|ggml-model-Q6_K.gguf',
'KBlueLeaf/DanTagGen-gamma|ggml-model-Q8_0.gguf',
'KBlueLeaf/DanTagGen-gamma|ggml-model-f16.gguf',
] as const

/** Allowed values for the overall length of the generated tag list. */
const dtgTagLengths = ['very short', 'short', 'long', 'very long'] as const

/**
 * Literal union derived from `dtgTagLengths`, so the config type and the
 * list of selectable values cannot drift apart.
 */
type DtgTagLength = (typeof dtgTagLengths)[number]

/** Settings for the sd-webui-dtg (DanTagGen) extension. */
export interface DanTagGenConfig {
  /** Whether the extension is enabled. */
  enabled: boolean
  /** Tag generation is skipped once the input is longer than this value. */
  disableAfterLen?: number
  /** Length of the generated tag list. */
  totalTagLength?: DtgTagLength
  /** Tags that must not be generated, separated by ","; regex is supported. */
  banTags?: string
  /** Format template for the generated prompt. */
  promptFormat?: string
  /** Seed used for upsampling. */
  upsamplingTagsSeed?: number
  /** Whether upsampling runs before or after other prompt processing. */
  upsamplingTiming?: 'Before' | 'After'
  /** Model used for generation. */
  model?: string
  /** Whether to use the CPU (GGUF). */
  useCpu?: boolean
  /** Whether formatting is disabled. */
  noFormatting?: boolean
  /** Sampling temperature. */
  temperature?: number
  /** Top-P sampling parameter. */
  topP?: number
  /** Top-K sampling parameter. */
  topK?: number
}

/**
 * Schema describing the sd-webui-dtg (DanTagGen) options.
 * Field order is kept as declared.
 */
export const danTagGenConfig = Schema.object({
  // Literal `true`, marked as required.
  enabled: Schema.const(true).required(),

  // Input length above which tag generation is no longer enabled.
  disableAfterLen: Schema.number()
    .role('slider')
    .min(0)
    .max(200)
    .description('在输入字数大于该值后不再启用tag生成。')
    .default(100),

  // Length of the generated tag list.
  totalTagLength: Schema.union(dtgTagLengths)
    .description('生成的tag长度。')
    .default('short'),

  // Tags that must not be generated ("," separated, regex supported).
  banTags: Schema.string()
    .role('textarea')
    .description('禁止生成的tag,使用","分隔,支持Regex。')
    .default(''),

  // Format template for the generated prompt.
  promptFormat: Schema.string()
    .role('textarea')
    .description('生成的格式。')
    .default('<|special|>, <|characters|>,<|artist|>, <|general|>, <|meta|>, <|rating|>'),

  // Upsampling seed.
  upsamplingTagsSeed: Schema.number()
    .description('上采样种子。')
    .default(-1),

  // Whether upsampling runs before or after other prompt processing.
  upsamplingTiming: Schema.union([
    Schema.const('Before').description('其他提示词处理前'),
    Schema.const('After').description('其他提示词处理后'),
  ])
    .description('上采样介入时机。')
    .default('After'),

  // Generation model, chosen from `dtgModels`.
  model: Schema.union(dtgModels)
    .description('生成模型。')
    .default('KBlueLeaf/DanTagGen-delta-rev2'),

  // Whether to use the CPU (GGUF).
  useCpu: Schema.boolean()
    .description('是否使用CPU(GGUF)。')
    .default(false),

  // Whether formatting is disabled.
  noFormatting: Schema.boolean()
    .description('是否禁用格式化。')
    .default(false),

  // Sampling parameters.
  temperature: Schema.number()
    .role('slider')
    .min(0.1)
    .max(2.5)
    .step(0.01)
    .description('温度。')
    .default(1),
  topP: Schema.number()
    .role('slider')
    .min(0)
    .max(1)
    .step(0.01)
    .description('Top P。')
    .default(0.8),
  topK: Schema.number()
    .role('slider')
    .min(0)
    .max(200)
    .description('Top K。')
    .default(80),
}).description('sd-webui-dtg 配置') as Schema<DanTagGenConfig>

interface FeatureConfig {
anlas?: Computed<boolean>
text?: Computed<boolean>
Expand Down Expand Up @@ -249,6 +309,7 @@ export interface Config extends PromptConfig, ParamConfig {
trustedWorkers?: boolean
workflowText2Image?: string
workflowImage2Image?: string
danTagGen?: DanTagGenConfig,
}

export const Config = Schema.intersect([
Expand Down Expand Up @@ -425,6 +486,22 @@ export const Config = Schema.intersect([
recallTimeout: Schema.number().role('time').description('图片发送后自动撤回的时间 (设置为 0 以禁用此功能)。').default(0),
maxConcurrency: Schema.number().description('单个频道下的最大并发数量 (设置为 0 以禁用此功能)。').default(0),
}).description('高级设置'),


Schema.union([
Schema.object({
type: Schema.const('sd-webui'),
danTagGen: danTagGenConfig,
})
]),
Schema.union([
Schema.object({
type: Schema.const('sd-webui'),
danTagGen: Schema.object({
enabled: Schema.boolean().description('是否启用 sd-webui-dtg 插件。').default(false),
}).description('sd-webui 插件设置'),
}),
]),
]) as Schema<Config>

interface Forbidden {
Expand Down Expand Up @@ -531,13 +608,14 @@ export function parseInput(session: Session, input: string, config: Config, over
return ['.too-many-words']
}

const sanitizedInput = positive.join(',')
if (!override) {
appendToList(positive, session.resolve(config.basePrompt))
appendToList(negative, session.resolve(config.negativePrompt))
if (config.defaultPromptSw) appendToList(positive, session.resolve(config.defaultPrompt))
}

return [null, positive.join(', '), negative.join(', ')]
return [null, positive.join(', '), negative.join(', '), sanitizedInput]
}

function getWordCount(words: string[]) {
Expand Down
13 changes: 8 additions & 5 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { Computed, Context, Dict, h, Logger, omit, Quester, Session, SessionErro
import { Config, modelMap, models, orientMap, parseInput, sampler, upscalers, scheduler } from './config'
import { ImageData, StableDiffusionWebUI } from './types'
import { closestMultiple, download, forceDataPrefix, getImageSize, login, NetworkError, project, resizeInput, Size } from './utils'
import { genExtensionsArgs } from './webuiExtension'
import { } from '@koishijs/translator'
import { } from '@koishijs/plugin-help'
import AdmZip from 'adm-zip'
Expand Down Expand Up @@ -118,10 +119,10 @@ export function apply(ctx: Context, config: Config) {
type: ['token', 'login'].includes(config.type)
? scheduler.nai
: config.type === 'sd-webui'
? scheduler.sd
: config.type === 'stable-horde'
? scheduler.horde
: [],
? scheduler.sd
: config.type === 'stable-horde'
? scheduler.horde
: [],
})
.option('decrisper', '-D', { hidden: thirdParty })
.option('undesired', '-u <undesired>')
Expand Down Expand Up @@ -205,7 +206,7 @@ export function apply(ctx: Context, config: Config) {
}
}

const [errPath, prompt, uc] = parseInput(session, input, config, options.override)
const [errPath, prompt, uc, sanitizedInput] = parseInput(session, input, config, options.override)
if (errPath) return session.text(errPath)

let token: string
Expand Down Expand Up @@ -360,6 +361,7 @@ export function apply(ctx: Context, config: Config) {
return { model, input: prompt, parameters: omit(parameters, ['prompt']) }
}
case 'sd-webui': {
const extensionsArgs = genExtensionsArgs(session, config, sanitizedInput)
return {
sampler_index: sampler.sd[options.sampler],
scheduler: options.scheduler,
Expand All @@ -368,6 +370,7 @@ export function apply(ctx: Context, config: Config) {
enable_hr: options.hiresFix ?? config.hiresFix ?? false,
hr_second_pass_steps: options.hiresFixSteps ?? 0,
hr_upscaler: config.hiresFixUpscaler ?? 'None',
alwayson_scripts: extensionsArgs,
...project(parameters, {
prompt: 'prompt',
batch_size: 'n_samples',
Expand Down
54 changes: 54 additions & 0 deletions src/webuiExtension.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import { Dict, Session } from "koishi"
import { Config, DanTagGenConfig } from "./config"

/** One value that may appear in an extension's argument list. */
type WebUiExtensionArg = boolean | number | string | object | null

/** Ordered argument list handed to a single WebUI extension. */
type WebUiExtensionArgs = WebUiExtensionArg[]

/** What an extension generator yields: the extension name and its arguments. */
interface ExtensionGenRes {
  name: string
  args: WebUiExtensionArgs
}

/** Extension arguments keyed by extension name, each wrapped as `{ args }`. */
interface WebUiExtensionParams {
  [extensionName: string]: { args: WebUiExtensionArgs }
}

/**
 * Builds the argument list for the DanTagGen extension.
 *
 * @param session Current session (not read by this generator).
 * @param config Plugin configuration; only `config.danTagGen` is consulted.
 * @param prompt Prompt text whose length is compared with `disableAfterLen`.
 * @returns The extension name with its arguments, or `false` when the
 *   extension is not configured, not enabled, or the prompt is longer than a
 *   positive `disableAfterLen`.
 */
function danTagGen(session: Session, config: Config, prompt: string): ExtensionGenRes | false {
  const settings = config.danTagGen
  if (!settings?.enabled) return false

  // A missing or non-positive limit means the length check never triggers.
  const limit = settings.disableAfterLen ?? 0
  const promptLength = prompt ? prompt.length : 0
  if (limit > 0 && promptLength > limit) return false

  // Anything other than 'After' maps to the "Before" label.
  const timing = settings.upsamplingTiming === 'After'
    ? 'After applying other prompt processings'
    : 'Before applying other prompt processings'

  // NOTE(review): the order presumably has to match the inputs the
  // sd-webui-dtg script expects — verify against the extension.
  return {
    name: 'DanTagGen',
    args: [
      true, // enable
      timing,
      settings.upsamplingTagsSeed,
      settings.totalTagLength,
      settings.banTags,
      settings.promptFormat,
      settings.temperature,
      settings.topP,
      settings.topK,
      settings.model,
      settings.useCpu,
      settings.noFormatting,
    ],
  }
}

// Registry of extension argument generators. Each entry receives the session,
// the plugin config and the prompt, and returns either the extension's name
// and arguments or `false` when it contributes nothing.
const webUiExtensions: Array<(session: Session, config: Config, prompt: string) => ExtensionGenRes | false> = [danTagGen]

/**
 * Collects the arguments of every registered WebUI extension generator.
 *
 * Each generator decides for itself whether it applies to the given config.
 * There is deliberately no up-front check on a single extension's settings
 * here, so an unconfigured extension cannot suppress the others in the
 * registry.
 *
 * @param session Current session, forwarded to every generator.
 * @param config Plugin configuration, forwarded to every generator.
 * @param prompt Prompt text, forwarded to every generator.
 * @returns A map from extension name to `{ args }`; empty when no generator
 *   produced a result.
 */
export function genExtensionsArgs(session: Session, config: Config, prompt: string): WebUiExtensionParams {
  const params: WebUiExtensionParams = {}
  for (const generate of webUiExtensions) {
    const result = generate(session, config, prompt)
    // Generators return `false` when they have nothing to contribute.
    if (!result) continue
    params[result.name] = { args: result.args }
  }
  return params
}