-
Notifications
You must be signed in to change notification settings - Fork 12
/
SubtitleTranslator.scala
234 lines (203 loc) · 8.95 KB
/
SubtitleTranslator.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
package tools
import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.stream.scaladsl.{FileIO, Flow, Source}
import org.apache.pekko.stream.{IOResult, ThrottleMode}
import org.apache.pekko.util.ByteString
import org.slf4j.{Logger, LoggerFactory}
import java.nio.file.Paths
import scala.concurrent.duration.DurationInt
import scala.concurrent.{ExecutionContextExecutor, Future}
import scala.util.{Failure, Success}
/**
* Translate all blocks of an English .srt file to a target lang using OpenAI API
*
* Workflow:
* - Load all blocks from the .srt source file with [[SrtParser]]
* - Group blocks to scenes (= all blocks within a session window), depending on `maxGapSeconds`
* - Translate all blocks of a scene in one prompt (one line per block) via the openAI API
* - Continuously write translated blocks to target file
*
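* Works with the OpenAI chat API; both models are called via `runChatCompletions` of [[OpenAICompletions]]:
* - Default model: gpt-4o via /chat/completions https://platform.openai.com/docs/guides/chat/chat-vs-completions
* - Fallback model: gpt-4-turbo, used when the number of translated lines does not match the number of blocks in the scene
* - The legacy /completions endpoint (https://platform.openai.com/docs/api-reference/completions/create) is not called from this class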
*
* Usage:
* - Wire .srt source file
* - Add API_KEY in [[OpenAICompletions]] and run this class
* - Scan log for WARN log messages and improve corresponding blocks in target file manually
* - Note that the numerical block headers in the .srt files are not interpreted, only timestamps matter
*
* Similar to: [[sample.stream.SessionWindow]]
*/
object SubtitleTranslator extends App {
// Logger of this tool; WARN entries point to blocks that should be checked manually in the target file
val logger: Logger = LoggerFactory.getLogger(this.getClass)
// Actor system that runs the stream; its dispatcher is the implicit execution context
// used for the completion callback in terminateWhen
implicit val system: ActorSystem = ActorSystem()
implicit val executionContext: ExecutionContextExecutor = system.dispatcher
// English .srt input, read by SrtParser
val sourceFilePath = "src/main/resources/EN_challenges.srt"
// Output file, written via FileIO.toPath
private val targetFilePath = "DE_challenges.srt"
// Language name interpolated into the translation prompt
private val targetLanguage = "German"
// Model tried first for each scene
private val defaultModel = "gpt-4o"
// Model used when the response of the default model fails isTranslationPlausible
private val fallbackModel = "gpt-4-turbo"
private val maxGapSeconds = 1 // gap time between two scenes (= session windows)
// Separator used to split an API response into one text line per block
private val endLineTag = "\n"
private val maxCharPerTranslatedLine = 40 // recommendation
// A cleaned response line starting with this prefix is treated as a two people conversation
private val conversationPrefix = "-"
// Running sum of the token counts reported by the translation requests; logged on termination
private var totalTokensUsed = 0
// Sync to ensure that all blocks are readable before translation starts
val parseResult = SrtParser(sourceFilePath).runSync()
logger.info("Number of subtitleBlocks to translate: {}", parseResult.length)
val source = Source(parseResult)
// Groups the incoming blocks into scenes and translates each scene as a whole;
// emits one collection of translated blocks per scene
val workflow = Flow[SubtitleBlock]
.via(groupByScene(maxGapSeconds))
.map(translateScene)
val fileSink = FileIO.toPath(Paths.get(targetFilePath))
// Passes a consecutive 1-based counter (zipWithIndex is 0-based) to formatOutBlock and writes
// the result to the target file. The materialized value of the file sink is kept, so running
// the stream yields the Future[IOResult] expected by terminateWhen.
val processingSink = Flow[SubtitleBlock]
.zipWithIndex
.map { case (block: SubtitleBlock, blockCounter: Long) =>
ByteString(block.formatOutBlock(blockCounter + 1))
}
.toMat(fileSink)((_, bytesWritten) => bytesWritten)
val done = source
// https://platform.openai.com/docs/guides/rate-limits/overview
// At most 25 elements per 60 seconds (burst 25). Note that the throttle sits upstream of the
// scene grouping, so it limits single blocks and not API requests.
.throttle(25, 60.seconds, 25, ThrottleMode.shaping)
.via(workflow)
.mapConcat(identity) // flatten
.runWith(processingSink)
// Shut down the actor system once the target file is written or the stream failed
terminateWhen(done)
// Partition the stream of blocks into scenes (= session windows).
// A block belongs to the current scene as long as the gap between its start and the end of the
// previous block is below maxGap seconds (start/end are presumably milliseconds - verify in SubtitleBlock).
private def groupByScene(maxGap: Int) = {
  val maxGapMillis = maxGap * 1000
  Flow[SubtitleBlock]
    .statefulMap(() => List.empty[SubtitleBlock])(
      (currentScene, block) => {
        // For the very first block there is no predecessor, so the block is compared with itself
        val previous = currentScene.lastOption.getOrElse(block)
        val gap = block.start - previous.end
        if (gap < maxGapMillis) {
          // Scene continues: keep collecting, emit nothing yet
          (currentScene :+ block, Nil)
        } else {
          // Gap reached: emit the finished scene, the block opens the next one
          (List(block), currentScene)
        }
      },
      // On completion the scene collected so far is emitted as the last element
      currentScene => Some(currentScene))
    // Drop the empty placeholders that are emitted while a scene is still being collected
    .filterNot(_.isEmpty)
}
/**
 * Translate all blocks of one scene with a single prompt.
 *
 * The text of all blocks is sent to `defaultModel`. If the number of non-empty response lines does
 * not match the number of blocks, the request is repeated with `fallbackModel`. The response lines
 * are then mapped back to the original blocks by position; only the text lines of a block are
 * replaced, everything else is taken over from the original block.
 *
 * The tokens of every request are added to `totalTokensUsed`, also those of a discarded default
 * response.
 *
 * @param sceneOrig the original blocks of the scene, expected to be non-empty
 * @return the translated blocks, one per non-empty response line
 */
private def translateScene(sceneOrig: List[SubtitleBlock]) = {
  logger.info(s"About to translate scene with: ${sceneOrig.size} original blocks")
  val allLines = sceneOrig.map(_.allLinesEnd).mkString
  val toTranslate = generateTranslationPrompt(allLines)
  logger.info(s"Translation prompt: $toTranslate")

  val translatedDefault = new OpenAICompletions().runChatCompletions(defaultModel, toTranslate)
  // Count these tokens right away: they are used even if the response is discarded below
  totalTokensUsed = totalTokensUsed + translatedDefault.getRight

  val translated =
    if (isTranslationPlausible(translatedDefault.getLeft, sceneOrig.size)) translatedDefault
    else {
      logger.info(s"Translation with: $defaultModel is not plausible, lines do not match. Fallback to: $fallbackModel")
      val translatedFallback = new OpenAICompletions().runChatCompletions(fallbackModel, toTranslate)
      totalTokensUsed = totalTokensUsed + translatedFallback.getRight
      // With too few lines the trailing blocks of the scene are missing in the target file,
      // so make this visible for the manual review of WARN messages
      if (!isTranslationPlausible(translatedFallback.getLeft, sceneOrig.size)) {
        logger.warn(s"Translation with: $fallbackModel is not plausible either, lines do not match. Check the blocks of this scene manually")
      }
      translatedFallback
    }

  val rawResponseText = translated.getLeft
  logger.debug("Response text: {}", rawResponseText)

  val sceneTranslated: Vector[SubtitleBlock] =
    rawResponseText
      .split(endLineTag)
      .filterNot(_.isEmpty)
      .toVector
      .zipWithIndex
      .map { case (translatedLine, index) =>
        val massagedResult = massageResultText(translatedLine)
        val origBlock = sceneOrig.lift(index).getOrElse {
          // Root cause: No plausible translation provided by openAI, eg due to added lines at beginning or at end of response
          logger.warn(s"This should not happen: sceneOrig has size: ${sceneOrig.size} but access to element: $index requested. Fallback to last original block")
          sceneOrig.last
        }
        val translatedBlock = origBlock.copy(lines = massagedResult)
        logger.info(s"Translated block to: ${translatedBlock.allLines}")
        translatedBlock
      }
  logger.info(s"Finished translation of scene with: ${sceneTranslated.size} blocks")
  sceneTranslated
}
/**
 * Check whether a response can be mapped back to the original blocks by position.
 *
 * @param rawResponseText the response text, lines separated by `endLineTag`
 * @param originalSize    the number of blocks that were sent for translation
 * @return true if the number of non-empty response lines equals `originalSize`
 */
private def isTranslationPlausible(rawResponseText: String, originalSize: Int) =
  rawResponseText.split(endLineTag).count(_.nonEmpty) == originalSize
/**
 * Build the prompt that asks for a translation from English to `targetLanguage`.
 *
 * The prompt requests a line separated answer; translateScene relies on this when it maps
 * the response lines back to the original blocks by position.
 *
 * Note: `stripMargin` runs after the interpolation, so it is applied to the lines of `text` as well.
 *
 * @param text the text lines to translate
 * @return the prompt text
 */
private def generateTranslationPrompt(text: String) = {
s"""
|Translate the text lines below from English to $targetLanguage.
|
|Desired format:
|<line separated list of translated text lines, honor all line breaks>
|
|Text lines:
|$text
|
|""".stripMargin
}
/**
 * Build the prompt that asks to rewrite a too long translated text.
 *
 * The requested limit is `maxCharPerTranslatedLine * 2 - 10` characters, which is below the
 * threshold of `maxCharPerTranslatedLine * 2 + 10` that triggers shortening in massageResultText.
 *
 * Note: `stripMargin` runs after the interpolation, so it is applied to the lines of `text` as well.
 *
 * @param text the translated text to shorten
 * @return the prompt text
 */
private def generateShortenPrompt(text: String) = {
s"""
|Rewrite to ${maxCharPerTranslatedLine * 2 - 10} characters at most:
|$text
|
|""".stripMargin
}
/**
 * Turn one translated response line into the text lines of a block.
 *
 *  - A text starting with `conversationPrefix` is a two people conversation: it is split in front
 *    of every prefix that follows whitespace, one line per speaker. Hyphens inside of words
 *    (eg "U-Bahn") do not split the text.
 *  - A text longer than `maxCharPerTranslatedLine * 2 + 10` characters is shortened via an
 *    additional API call (tokens are added to `totalTokensUsed`) and then split.
 *  - Any other text is split with splitSentence.
 *
 * @param text the raw response line
 * @return the text lines of the block
 */
private def massageResultText(text: String) = {
  val textCleaned = clean(text)
  val maxBlockLength = maxCharPerTranslatedLine * 2 + 10
  // Two people conversation in one block
  if (textCleaned.startsWith(conversationPrefix)) {
    // Zero-width split position: after whitespace and in front of the prefix. The prefix stays
    // part of the line, and a text consisting of the prefix only yields that single line.
    val speakerBoundary = "(?<=\\s)(?=" + java.util.regex.Pattern.quote(conversationPrefix) + ")"
    textCleaned.split(speakerBoundary).map(_.trim).filter(_.nonEmpty).toList
  }
  else if (textCleaned.length > maxBlockLength) {
    logger.warn(s"Translated block text is too long (${textCleaned.length} chars). Try to shorten via API call. Check result manually")
    val toShorten = generateShortenPrompt(textCleaned)
    logger.info(s"Shorten prompt: $toShorten")
    val responseShort = new OpenAICompletions().runChatCompletions(defaultModel, toShorten)
    // The shorten request uses tokens as well
    totalTokensUsed = totalTokensUsed + responseShort.getRight
    splitSentence(clean(responseShort.getLeft))
  }
  else splitSentence(textCleaned)
}
/**
 * Remove control characters and a pair of enclosing double quotes.
 *
 * All characters below the space character (eg tabs, line breaks) are dropped. The quotes are only
 * removed if the text starts AND ends with a double quote; a text with just a leading quote
 * (eg quoted speech at the beginning of a sentence) is returned as is.
 *
 * @param text the text to clean
 * @return the cleaned text
 */
private def clean(text: String) = {
  val quote = "\""
  val filtered = text.filter(_ >= ' ')
  // The length check protects against a text that consists of a single quote only
  val isEnclosedInQuotes = filtered.length >= 2 && filtered.startsWith(quote) && filtered.endsWith(quote)
  if (isEnclosedInQuotes) filtered.substring(1, filtered.length - 1)
  else filtered
}
/**
 * Split a text into at most two lines if it is longer than `maxCharPerTranslatedLine`.
 *
 * The text is split after the first comma, provided that both parts are longer than 15
 * characters. Otherwise it is split at a word boundary via splitSentenceHonorWords.
 * The second part of a comma split is trimmed, so it does not start with the blank that
 * usually follows the comma.
 *
 * @param text the text to split
 * @return one line if the text is short enough, otherwise the parts after splitting
 */
private def splitSentence(text: String) = {
  val minPartLength = 15
  if (text.length <= maxCharPerTranslatedLine) List(text)
  else {
    // indexOf yields -1 if there is no comma, which fails the check below
    val indexFirstComma = text.indexOf(",")
    val isCommaUsable = indexFirstComma > minPartLength && indexFirstComma < text.length - minPartLength
    if (isCommaUsable)
      List(text.substring(0, indexFirstComma + 1), text.substring(indexFirstComma + 1).trim)
    else splitSentenceHonorWords(text)
  }
}
/**
 * Split a sentence into two lines at the word boundary in the middle (by number of words).
 *
 * A sentence that does not consist of at least two words (eg one long compound word) can not be
 * split and is returned as a single line instead of an empty line followed by the word.
 *
 * @param sentence the sentence to split
 * @return two lines, or one line if the sentence can not be split
 */
private def splitSentenceHonorWords(sentence: String) = {
  val words = sentence.split(" ")
  if (words.length < 2) List(sentence)
  else {
    val (firstHalf, secondHalf) = words.splitAt(words.length / 2)
    List(firstHalf.mkString(" "), secondHalf.mkString(" "))
  }
}
/**
 * Terminate the actor system as soon as the stream has completed.
 *
 * Success is logged at INFO together with the number of tokens used. A failure is logged at
 * ERROR including the stack trace, so that it is not overlooked when the log is scanned.
 *
 * @param done the materialized value of the stream, completed when writing has finished
 */
def terminateWhen(done: Future[IOResult]): Unit = {
  done.onComplete { result =>
    result match {
      case Success(_) =>
        logger.info(s"Flow Success. Finished writing to target file: $targetFilePath. Around $totalTokensUsed tokens used. About to terminate...")
      case Failure(e) =>
        logger.error(s"Flow Failure: $e. Partial translations are in target file: $targetFilePath About to terminate...", e)
    }
    // Terminate in both cases, otherwise the JVM keeps running
    system.terminate()
  }
}
}