diff --git a/.dockerignore b/.dockerignore
index 9ff771b..b14ec5b 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,7 +1,6 @@
*.txt
.vscode
build
-onnx*
tmp
lake-packages
.lake
diff --git a/.gitignore b/.gitignore
index 22dc57e..519f0f9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,4 @@
/.lake
.vscode
-*.olean
\ No newline at end of file
+*.olean
diff --git a/Dockerfile b/Dockerfile
index 7d0b7b3..777969c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,6 +1,6 @@
FROM ubuntu:latest
-WORKDIR /LeanInfer
+WORKDIR /LeanCopilot
COPY . .
# Install dependencies.
@@ -14,5 +14,5 @@ RUN curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -
# Build the Lean project.
RUN lake build
-RUN lake script run LeanInfer/download
-RUN lake build LeanInferTests
\ No newline at end of file
+RUN lake script run LeanCopilot/download
+RUN lake build LeanCopilotTests
\ No newline at end of file
diff --git a/LeanCopilot.lean b/LeanCopilot.lean
new file mode 100644
index 0000000..3091005
--- /dev/null
+++ b/LeanCopilot.lean
@@ -0,0 +1,5 @@
+import LeanCopilot.Models
+import LeanCopilot.Frontend
+import LeanCopilot.Options
+import LeanCopilot.Tactics
+import LeanCopilot.LlmAesop
diff --git a/LeanInfer/Frontend.lean b/LeanCopilot/Frontend.lean
similarity index 100%
rename from LeanInfer/Frontend.lean
rename to LeanCopilot/Frontend.lean
diff --git a/LeanCopilot/LlmAesop.lean b/LeanCopilot/LlmAesop.lean
new file mode 100644
index 0000000..0e81757
--- /dev/null
+++ b/LeanCopilot/LlmAesop.lean
@@ -0,0 +1,23 @@
+import LeanCopilot.Tactics
+import LeanCopilot.Options
+import Aesop
+
+open Lean Lean.Elab.Command
+
+namespace LeanCopilot
+
+
+def tacGen : Aesop.TacGen := fun (mvarId : MVarId) => do
+ let state ← ppTacticState [mvarId]
+ let nm ← SuggestTactics.getGeneratorName
+ let model ← getGenerator nm
+ generate model state ""
+
+
+macro "#init_llm_aesop" : command => `(@[aesop 100%] def tacGen := LeanCopilot.tacGen)
+
+
+macro "search_proof" : tactic => `(tactic| aesop?)
+
+
+end LeanCopilot
diff --git a/LeanCopilot/Models.lean b/LeanCopilot/Models.lean
new file mode 100644
index 0000000..e2f5b3d
--- /dev/null
+++ b/LeanCopilot/Models.lean
@@ -0,0 +1,4 @@
+import LeanCopilot.Models.Interface
+import LeanCopilot.Models.Defs
+import LeanCopilot.Models.Registry
+import LeanCopilot.Models.FFI
diff --git a/LeanCopilot/Models/Builtin.lean b/LeanCopilot/Models/Builtin.lean
new file mode 100644
index 0000000..c26e4b7
--- /dev/null
+++ b/LeanCopilot/Models/Builtin.lean
@@ -0,0 +1,27 @@
+import ModelCheckpointManager
+import LeanCopilot.Models.ByT5
+
+set_option autoImplicit false
+
+namespace LeanCopilot.Builtin
+
+
+def generator : NativeGenerator := {
+ url := Url.parse! "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-tacgen-byt5-small"
+ tokenizer := ByT5.tokenizer
+ params := {
+ numReturnSequences := 32
+ }
+}
+
+
+def encoder : NativeEncoder := {
+ url := Url.parse! "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-retriever-byt5-small"
+ tokenizer := ByT5.tokenizer
+}
+
+
+def premisesUrl := Url.parse! "https://huggingface.co/kaiyuy/premise-embeddings-leandojo-lean4-retriever-byt5-small"
+
+
+end LeanCopilot.Builtin
diff --git a/LeanInfer/Tokenization.lean b/LeanCopilot/Models/ByT5.lean
similarity index 85%
rename from LeanInfer/Tokenization.lean
rename to LeanCopilot/Models/ByT5.lean
index 7ec2d24..a310e0f 100644
--- a/LeanInfer/Tokenization.lean
+++ b/LeanCopilot/Models/ByT5.lean
@@ -1,8 +1,9 @@
-namespace LeanInfer
+import LeanCopilot.Models.Defs
-#eval "\t".length
+namespace LeanCopilot.ByT5
-def byt5Tokens : Array String := #[
+
+def vocab : Array String := #[
"\u0000",
"\u0001",
"\u0002",
@@ -263,23 +264,29 @@ def byt5Tokens : Array String := #[
def byteToToken (b : UInt8) : String :=
- byt5Tokens.get! b.toNat
+ vocab.get! b.toNat
def tokenToByte! (t : String) : UInt8 :=
- byt5Tokens.findIdx? (· = t) |>.get! |>.toUInt8
+ vocab.findIdx? (· = t) |>.get! |>.toUInt8
-def tokenizeByt5 (text : String) (addEOS : Bool) : List String :=
- let tokens := byteToToken <$> text.toUTF8.toList
- if addEOS then
- tokens ++ [""]
- else
- tokens
+def tokenize (text : String) : Array String :=
+ (byteToToken <$> text.toUTF8.toList).toArray
-def detokenizeByt5 (tokens : Array String) : String :=
+def detokenize (tokens : Array String) : String :=
String.fromUTF8Unchecked ⟨tokens.map tokenToByte!⟩
-end LeanInfer
+def eosToken := ""
+
+
+def tokenizer : Tokenizer := {
+ tokenize := tokenize,
+ detokenize := detokenize,
+ eosToken := eosToken
+}
+
+
+end LeanCopilot.ByT5
diff --git a/LeanCopilot/Models/Defs.lean b/LeanCopilot/Models/Defs.lean
new file mode 100644
index 0000000..617f6a0
--- /dev/null
+++ b/LeanCopilot/Models/Defs.lean
@@ -0,0 +1,170 @@
+import Lean
+import ModelCheckpointManager
+import LeanCopilot.Models.Interface
+
+set_option autoImplicit false
+
+open Lean
+open System (FilePath)
+
+namespace LeanCopilot
+
+
+inductive Device where
+ | cpu
+ | cuda
+ | auto
+deriving Repr
+
+
+instance : Inhabited Device where
+ default := .auto
+
+
+def Device.toString : Device → String
+ | Device.cpu => "cpu"
+ | Device.cuda => "cuda"
+ | Device.auto => "auto"
+
+instance : ToString Device := ⟨Device.toString⟩
+
+
+inductive ComputeType where
+ | default
+ | auto
+ | int8
+ | int8_float32
+ | int8_float16
+ | int8_bfloat16
+ | int16
+ | float16
+ | bfloat16
+ | float32
+deriving Repr
+
+
+def ComputeType.toString : ComputeType → String
+ | ComputeType.default => "default"
+ | ComputeType.auto => "auto"
+ | ComputeType.int8 => "int8"
+ | ComputeType.int8_float32 => "int8_float32"
+ | ComputeType.int8_float16 => "int8_float16"
+ | ComputeType.int8_bfloat16 => "int8_bfloat16"
+ | ComputeType.int16 => "int16"
+ | ComputeType.float16 => "float16"
+ | ComputeType.bfloat16 => "bfloat16"
+ | ComputeType.float32 => "float32"
+
+
+instance : ToString ComputeType := ⟨ComputeType.toString⟩
+
+
+structure Tokenizer where
+ tokenize : String → Array String
+ detokenize : Array String → String
+ eosToken : String
+
+
+structure NativeModel where
+ url : Url
+ device : Device := .auto
+ deviceIndex : Array UInt64 := #[0]
+ computeType : ComputeType := .default
+ tokenizer : Tokenizer
+
+
+def NativeModel.name (model : NativeModel) : String := model.url.name!
+
+
+def NativeModel.path (model : NativeModel) : IO FilePath :=
+ getModelDir model.url
+
+
+structure BeamSearchParams where
+ numReturnSequences : UInt64
+ beamSize : UInt64 := numReturnSequences
+ minLength : UInt64 := 1
+ maxLength : UInt64 := 1024
+ lengthPenalty : Float := 0.0
+ patience : Float := 2.0
+ temperature : Float := 1.0
+deriving Repr
+
+
+structure NativeGenerator extends NativeModel where
+ params : BeamSearchParams
+
+
+structure NativeEncoder extends NativeModel
+
+
+structure ExternalModel where
+ name : String
+ host : String := "localhost"
+ port : UInt16 := 23333
+deriving Inhabited, Repr
+
+
+structure ExternalGenerator extends ExternalModel
+deriving Repr
+
+
+structure ExternalRequest where
+ name : String
+ input : String
+ «prefix» : String
+deriving ToJson
+
+
+structure ExternalResponse where
+ outputs : Array (String × Float)
+deriving FromJson
+
+
+def ExternalGenerator.generate (model : ExternalGenerator) (input : String) (targetPrefix : String) : IO $ Array (String × Float) := do
+ let url := s!"http://{model.host}:{model.port}/generate"
+ let req : ExternalRequest := {
+ name := model.name,
+ input := input,
+ «prefix» := targetPrefix
+ }
+ let reqStr := (toJson req).pretty 99999999999999999
+ let out ← IO.Process.run {
+ cmd := "curl"
+ args := #["-X", "POST", url, "-H", "accept: application/json", "-H", "Content-Type: application/json", "-d", reqStr]
+ }
+
+ let some json := Json.parse out |>.toOption | throw $ IO.userError "Failed to parse response"
+ let some res := (fromJson? json : Except String ExternalResponse) |>.toOption | throw $ IO.userError "Failed to parse response"
+ return res.outputs
+
+
+instance : TextToText ExternalGenerator := ⟨ExternalGenerator.generate⟩
+
+
+structure ExternalEncoder extends ExternalModel
+deriving Repr
+
+
+def ExternalEncoder.encode (model : ExternalEncoder) (input : String) : IO FloatArray := do
+ return FloatArray.mk #[0.0]
+
+
+instance : TextToVec ExternalEncoder := ⟨ExternalEncoder.encode⟩
+
+
+structure GenericGenerator where
+ generate : String → String → IO (Array (String × Float))
+
+
+instance : TextToText GenericGenerator := ⟨GenericGenerator.generate⟩
+
+
+structure GenericEncoder where
+ encode : String → IO FloatArray
+
+
+instance : TextToVec GenericEncoder := ⟨GenericEncoder.encode⟩
+
+
+end LeanCopilot
diff --git a/LeanCopilot/Models/FFI.lean b/LeanCopilot/Models/FFI.lean
new file mode 100644
index 0000000..b8b1c1a
--- /dev/null
+++ b/LeanCopilot/Models/FFI.lean
@@ -0,0 +1,142 @@
+import LeanCopilot.Models.Defs
+import LeanCopilot.Models.Builtin
+
+namespace LeanCopilot
+
+set_option autoImplicit false
+
+namespace FFI
+
+@[extern "is_generator_initialized"]
+opaque isGeneratorInitialized : (name : @& String) → Bool
+
+@[extern "is_encoder_initialized"]
+opaque isEncoderInitialized : (name : @& String) → Bool
+
+@[extern "init_generator"]
+opaque initGenerator (name : @& String) (modelPath : @& String) (computeType : @& String) (device : @& String) (deviceIndex : @& Array UInt64) : Bool
+
+@[extern "init_encoder"]
+opaque initEncoder (name : @& String) (modelPath : @& String) (computeType : @& String) (device : @& String) (deviceIndex : @& Array UInt64) : Bool
+
+@[extern "generate"]
+opaque generate (name : @& String) (inputTokens : @& Array String) (targetPrefixTokens : @& Array String) (numReturnSequences : UInt64) (beamSize : UInt64)
+ (minLength : UInt64) (maxLength : UInt64) (lengthPenalty : Float) (patience : Float) (temperature : Float)
+ : Array (Array String × Float)
+
+@[extern "encode"]
+opaque encode (name : @& String) (inputTokens : @& Array String) : FloatArray
+
+@[extern "init_premise_embeddings"]
+opaque initPremiseEmbeddings (path : @& String) (device : @& String) : Bool
+
+@[extern "premise_embeddings_initialized"]
+opaque premiseEmbeddingsInitialized : Unit → Bool
+
+@[extern "init_premise_dictionary"]
+opaque initPremiseDictionary (path : @& String) : Bool
+
+@[extern "premise_dictionary_initialized"]
+opaque premiseDictionaryInitialized : Unit → Bool
+
+@[extern "retrieve"]
+opaque retrieve (queryEmb : @& FloatArray) (k : UInt64) : Array (String × String × String × Float)
+
+@[extern "cuda_available"]
+opaque cudaAvailable : Unit → Bool
+
+end FFI
+
+
+def cudaAvailable : Bool := FFI.cudaAvailable ()
+
+
+namespace NativeGenerator
+
+
+def generate (model : NativeGenerator) (input : String) (targetPrefix : String) : IO $ Array (String × Float) := do
+ if ¬ FFI.isGeneratorInitialized model.name then
+ let path ← model.path
+ if ¬ (← path.pathExists) then
+ throw $ IO.userError s!"Cannot find the model {model.name}. Please run `lake exe download {model.url}`."
+ let device := model.device.toString
+ let computeType := model.computeType.toString
+ if ¬ (FFI.initGenerator model.name path.toString computeType device model.deviceIndex) then
+ throw $ IO.userError s!"Failed to initialize model {model.name}"
+
+ let tokenizer := model.tokenizer
+ let inputTokens := tokenizer.tokenize input |>.push tokenizer.eosToken
+ let targetPrefixTokens := tokenizer.tokenize targetPrefix
+ let numReturnSequences := model.params.numReturnSequences
+ let beamSize := model.params.beamSize
+ let minLength := model.params.minLength
+ let maxLength := model.params.maxLength
+ let lengthPenalty := model.params.lengthPenalty
+ let patience := model.params.patience
+ let temperature := model.params.temperature
+ let tokensWithScores := FFI.generate model.name inputTokens targetPrefixTokens numReturnSequences beamSize minLength maxLength lengthPenalty patience temperature
+
+ return tokensWithScores.filterMap fun ((ts, s) : Array String × Float) =>
+ match tokenizer.detokenize ts with
+ | "aesop" => none
+ | t => some (t, s)
+
+
+
+instance : TextToText NativeGenerator where
+ generate := NativeGenerator.generate
+
+
+end NativeGenerator
+
+
+namespace NativeEncoder
+
+
+def encode (model : NativeEncoder) (input : String) : IO FloatArray := do
+ if ¬ FFI.isEncoderInitialized model.name then
+ let path ← model.path
+ if ¬ (← path.pathExists) then
+ throw $ IO.userError s!"Cannot find the model {model.name}. Please run `lake exe download {model.url}`."
+ let device := model.device.toString
+ let computeType := model.computeType.toString
+ if ¬ (FFI.initEncoder model.name path.toString computeType device model.deviceIndex) then
+ throw $ IO.userError s!"Failed to initialize model {model.name}"
+
+ let tokenizer := model.tokenizer
+ let inputTokens := tokenizer.tokenize input |>.push tokenizer.eosToken
+ return FFI.encode model.name inputTokens
+
+
+instance : TextToVec NativeEncoder where
+ encode := NativeEncoder.encode
+
+
+end NativeEncoder
+
+
+def premiseEmbeddingsInitialized : IO Bool := do
+ return FFI.premiseEmbeddingsInitialized ()
+
+
+def initPremiseEmbeddings (device : Device) : IO Bool := do
+ let path := (← getModelDir Builtin.premisesUrl) / "embeddings.npy"
+ if ¬ (← path.pathExists) then
+ throw $ IO.userError s!"Please run `lake exe download {Builtin.premisesUrl}` to download premise embeddings."
+ return false
+ return FFI.initPremiseEmbeddings path.toString device.toString
+
+
+def premiseDictionaryInitialized : IO Bool := do
+ return FFI.premiseDictionaryInitialized ()
+
+
+def initPremiseDictionary : IO Bool := do
+ let path := (← getModelDir Builtin.premisesUrl) / "dictionary.json"
+ if ¬ (← path.pathExists) then
+ throw $ IO.userError s!"Please run `lake exe download {Builtin.premisesUrl}` to download the premise dictionary."
+ return false
+ return FFI.initPremiseDictionary path.toString
+
+
+end LeanCopilot
diff --git a/LeanCopilot/Models/Interface.lean b/LeanCopilot/Models/Interface.lean
new file mode 100644
index 0000000..881027f
--- /dev/null
+++ b/LeanCopilot/Models/Interface.lean
@@ -0,0 +1,22 @@
+set_option autoImplicit false
+
+namespace LeanCopilot
+
+
+class TextToText (τ : Type) where
+ generate (model : τ) (input : String) (targetPrefix : String) : IO $ Array (String × Float)
+
+
+class TextToVec (τ : Type) where
+ encode : τ → String → IO FloatArray
+
+
+def generate {τ : Type} [TextToText τ] (model : τ) (input : String) (targetPrefix : String := "") : IO $ Array (String × Float) :=
+ TextToText.generate model input targetPrefix
+
+
+def encode {τ : Type} [TextToVec τ] (model : τ) (input : String) : IO FloatArray :=
+ TextToVec.encode model input
+
+
+end LeanCopilot
diff --git a/LeanCopilot/Models/Registry.lean b/LeanCopilot/Models/Registry.lean
new file mode 100644
index 0000000..3b674f3
--- /dev/null
+++ b/LeanCopilot/Models/Registry.lean
@@ -0,0 +1,98 @@
+import LeanCopilot.Models.Defs
+import LeanCopilot.Models.Builtin
+import LeanCopilot.Models.FFI
+import Std.Data.HashMap
+
+set_option autoImplicit false
+
+open Std
+
+namespace LeanCopilot
+
+
+inductive Generator where
+ | native : NativeGenerator → Generator
+ | external : ExternalGenerator → Generator
+ | generic : GenericGenerator → Generator
+
+
+instance : TextToText Generator where
+ generate (model : Generator) (input : String) (targetPrefix : String) :=
+ match model with
+ | .native ng => ng.generate input targetPrefix
+ | .external eg => eg.generate input targetPrefix
+ | .generic gg => gg.generate input targetPrefix
+
+
+inductive Encoder where
+ | native : NativeEncoder → Encoder
+ | external : ExternalEncoder → Encoder
+ | generic : GenericEncoder → Encoder
+
+
+instance : TextToVec Encoder where
+ encode (model : Encoder) (input : String) :=
+ match model with
+ | .native ne => ne.encode input
+ | .external ee => ee.encode input
+ | .generic ge => ge.encode input
+
+
+instance {α β : Type} [BEq α] [Hashable α] [Repr α] [Repr β] : Repr (HashMap α β) where
+ reprPrec hm n := reprPrec hm.toList n
+
+
+structure ModelRegistry where
+ generators : HashMap String Generator :=
+ HashMap.ofList [(Builtin.generator.name, .native Builtin.generator)]
+ encoders : HashMap String Encoder :=
+ HashMap.ofList [(Builtin.encoder.name, .native Builtin.encoder)]
+
+
+namespace ModelRegistry
+
+
+def generatorNames (mr : ModelRegistry) : List String :=
+ mr.generators.toList.map (·.1)
+
+
+def encoderNames (mr : ModelRegistry) : List String :=
+ mr.encoders.toList.map (·.1)
+
+
+def modelNames (mr : ModelRegistry) : List String :=
+ mr.generatorNames ++ mr.encoderNames
+
+
+end ModelRegistry
+
+
+instance : Repr ModelRegistry where
+ reprPrec mr n := reprPrec mr.modelNames n
+
+
+instance : Inhabited ModelRegistry where
+ default := {}
+
+
+initialize modelRegistryRef : IO.Ref ModelRegistry ← IO.mkRef default
+
+
+def getModelRegistry : IO ModelRegistry := modelRegistryRef.get
+
+
+def getGenerator (name : String) : IO Generator := do
+ let mr ← getModelRegistry
+ match mr.generators.find? name with
+ | some descr => return descr
+ | none => throw $ IO.userError s!"unknown generator: {name}"
+
+
+def getEncoder (name : String) : IO Encoder := do
+ let mr ← getModelRegistry
+ match mr.encoders.find? name with
+ | some descr => return descr
+ | none => throw $ IO.userError s!"unknown encoder: {name}"
+
+
+end LeanCopilot
diff --git a/LeanCopilot/Options.lean b/LeanCopilot/Options.lean
new file mode 100644
index 0000000..a49f204
--- /dev/null
+++ b/LeanCopilot/Options.lean
@@ -0,0 +1,74 @@
+import Lean
+import LeanCopilot.Models
+
+set_option autoImplicit false
+
+open Lean
+
+namespace LeanCopilot
+
+section
+
+
+variable {m : Type → Type} [Monad m] [MonadOptions m] [MonadEnv m] [MonadLift IO m]
+
+
+register_option LeanCopilot.verbose : Bool := {
+ defValue := false
+ descr := "Log various debugging information when running LeanCopilot."
+}
+
+
+def isVerbose : m Bool := do
+ match LeanCopilot.verbose.get? (← getOptions) with
+ | some true => return true
+ | _ => return false
+
+
+namespace SuggestTactics
+
+
+register_option LeanCopilot.suggest_tactics.check : Bool := {
+ defValue := true
+ descr := "Check if the generated tactics are valid or if they can prove the goal."
+}
+
+def checkTactics : CoreM Bool := do
+ match LeanCopilot.suggest_tactics.check.get? (← getOptions) with
+ | some false => return false
+ | _ => return true
+
+
+register_option LeanCopilot.suggest_tactics.model : String := {
+ defValue := Builtin.generator.name
+}
+
+
+def getGeneratorName : m String := do
+ match LeanCopilot.suggest_tactics.model.get? (← getOptions) with
+ | some n => return n
+ | _ => return Builtin.generator.name
+
+
+end SuggestTactics
+
+
+namespace SelectPremises
+
+
+register_option LeanCopilot.select_premises.k : Nat := {
+ defValue := 16
+}
+
+
+def getNumPremises : m Nat := do
+ match LeanCopilot.select_premises.k.get? (← getOptions) with
+ | some k => return k
+ | _ => return 16
+
+
+end SelectPremises
+
+end
+
+end LeanCopilot
diff --git a/LeanCopilot/Tactics.lean b/LeanCopilot/Tactics.lean
new file mode 100644
index 0000000..243b5b6
--- /dev/null
+++ b/LeanCopilot/Tactics.lean
@@ -0,0 +1,107 @@
+import Lean
+import LeanCopilot.Options
+import LeanCopilot.Frontend
+import Aesop.Util.Basic
+
+open Lean Meta Elab Tactic
+
+set_option autoImplicit false
+
+namespace LeanCopilot
+
+
+/--
+Pretty-print a list of goals.
+-/
+def ppTacticState : List MVarId → MetaM String
+ | [] => return "no goals"
+ | [g] => return (← Meta.ppGoal g).pretty
+ | goals =>
+ return (← goals.foldlM (init := "") (fun a b => do return s!"{a}\n\n{(← Meta.ppGoal b).pretty}")).trim
+
+
+/--
+Pretty-print the current tactic state.
+-/
+def getPpTacticState : TacticM String := do
+ let goals ← getUnsolvedGoals
+ ppTacticState goals
+
+
+@[implemented_by Meta.evalExpr]
+opaque evalExpr (α) (expectedType : Expr) (value : Expr) (safety := DefinitionSafety.safe) : MetaM α
+
+
+open SuggestTactics in
+/--
+Generate a list of tactic suggestions.
+-/
+def suggestTactics (targetPrefix : String) : TacticM (Array (String × Float)) := do
+ let state ← getPpTacticState
+ if ← isVerbose then
+ logInfo s!"State:\n{state}"
+ let nm ← getGeneratorName
+ let model ← getGenerator nm
+ generate model state targetPrefix
+
+
+def annotatePremise (premisesWithInfoAndScores : String × String × String × Float) : MetaM String := do
+ let (premise, path, code, _) := premisesWithInfoAndScores
+ let declName := premise.toName
+ try
+ let info ← getConstInfo declName
+ let premise_type ← Meta.ppExpr info.type
+ let some doc_str ← findDocString? (← getEnv) declName
+ | return s!"\n{premise} : {premise_type}\n"
+ return s!"\n{premise} : {premise_type}\n\n{doc_str}\n"
+ catch _ => return s!"\n{premise} needs to be imported from {path}.\n\n```\n{code}\n```\n"
+
+
+def retrieve (input : String) : TacticM (Array (String × String × String × Float)) := do
+ if ¬ (← premiseEmbeddingsInitialized) ∧ ¬ (← initPremiseEmbeddings .auto) then
+ throwError "Cannot initialize premise embeddings"
+
+ if ¬ (← premiseDictionaryInitialized) ∧ ¬ (← initPremiseDictionary) then
+ throwError "Cannot initialize premise dictionary"
+
+ let k ← SelectPremises.getNumPremises
+ let query ← encode Builtin.encoder input
+
+ return FFI.retrieve query k.toUInt64
+
+
+def selectPremises : TacticM (Array (String × String × String × Float)) := do
+ retrieve (← getPpTacticState)
+
+
+syntax "pp_state" : tactic
+syntax "suggest_tactics" : tactic
+syntax "suggest_tactics" str : tactic
+syntax "select_premises" : tactic
+
+
+macro_rules
+ | `(tactic | suggest_tactics%$tac) => `(tactic | suggest_tactics%$tac "")
+
+
+elab_rules : tactic
+ | `(tactic | pp_state) => do
+ let state ← getPpTacticState
+ logInfo state
+
+ | `(tactic | suggest_tactics%$tac $pfx:str) => do
+ let (tacticsWithScores, elapsed) ← Aesop.time $ suggestTactics pfx.getString
+ if ← isVerbose then
+ logInfo s!"{elapsed.printAsMillis} for generating {tacticsWithScores.size} tactics"
+ let tactics := tacticsWithScores.map (·.1)
+ addSuggestions tac pfx tactics.toList (← SuggestTactics.checkTactics)
+
+ | `(tactic | select_premises) => do
+ let premisesWithInfoAndScores ← selectPremises
+ let rankedPremisesWithInfoAndScores := premisesWithInfoAndScores.qsort (·.2.2.2 > ·.2.2.2)
+ let richPremises ← Meta.liftMetaM $ (rankedPremisesWithInfoAndScores.mapM annotatePremise)
+ let richPremisesExpand := richPremises.foldl (init := "") (· ++ · ++ "\n")
+ logInfo richPremisesExpand
+
+
+end LeanCopilot
diff --git a/LeanCopilotTests/HighLevelAPIs.lean b/LeanCopilotTests/HighLevelAPIs.lean
new file mode 100644
index 0000000..ed802a2
--- /dev/null
+++ b/LeanCopilotTests/HighLevelAPIs.lean
@@ -0,0 +1,40 @@
+import Lean
+import LeanCopilot
+
+open Lean Meta
+open LeanCopilot
+
+#eval (SuggestTactics.getGeneratorName : CoreM _)
+
+-- set_option LeanCopilot.verbose false
+
+#eval getModelRegistry
+
+
+-- set_option LeanCopilot.suggest_tactics.check false
+
+-- set_option LeanCopilot.suggest_tactics.model "ct2-leandojo-lean4-retriever-byt5-small"
+
+example (a b c : Nat) : a + b + c = a + c + b := by
+ suggest_tactics
+ sorry
+
+
+example (a b c : Nat) : a + b + c = a + c + b := by
+ suggest_tactics "rw" -- You may provide a prefix to constrain the generated tactics.
+ sorry
+
+
+example (a b c : Nat) : a + b + c = a + c + b := by
+ select_premises
+ sorry
+
+-- The example below wouldn't work without it.
+#init_llm_aesop
+
+example (a b c : Nat) : a + b + c = c + b + a := by
+ aesop?
+
+
+example (a b c : Nat) : a + b + c = c + b + a := by
+ search_proof
diff --git a/LeanCopilotTests/LowLevelAPIs.lean b/LeanCopilotTests/LowLevelAPIs.lean
new file mode 100644
index 0000000..79e2768
--- /dev/null
+++ b/LeanCopilotTests/LowLevelAPIs.lean
@@ -0,0 +1,176 @@
+import LeanCopilot
+
+set_option autoImplicit false
+
+open LeanCopilot
+
+#eval cudaAvailable
+
+/-
+```python
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+tokenizer = AutoTokenizer.from_pretrained("kaiyuy/leandojo-lean4-tacgen-byt5-small")
+model = AutoModelForSeq2SeqLM.from_pretrained("kaiyuy/leandojo-lean4-tacgen-byt5-small")
+
+state = "n : ℕ\n⊢ gcd n n = n"
+tokenized_state = tokenizer(state, return_tensors="pt")
+
+# Generate a single tactic.
+tactic_ids = model.generate(tokenized_state.input_ids, max_length=1024)
+tactic = tokenizer.decode(tactic_ids[0], skip_special_tokens=True)
+print(tactic, end="\n\n")
+
+# Generate multiple tactics via beam search.
+tactic_candidates_ids = model.generate(
+ tokenized_state.input_ids,
+ max_length=1024,
+ num_beams=4,
+ length_penalty=0.0,
+ do_sample=False,
+ num_return_sequences=4,
+ early_stopping=False,
+)
+tactic_candidates = tokenizer.batch_decode(
+ tactic_candidates_ids, skip_special_tokens=True
+)
+for tac in tactic_candidates:
+ print(tac)
+```
+-/
+
+
+def model₁ : NativeGenerator := {
+ url := Url.parse! "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-tacgen-byt5-small"
+ tokenizer := ByT5.tokenizer
+ params := {
+ numReturnSequences := 1
+ }
+}
+
+#eval generate model₁ "n : ℕ\n⊢ gcd n n = n"
+
+
+def model₁' : NativeGenerator := {model₁ with params := {numReturnSequences := 4}}
+
+#eval generate model₁' "n : ℕ\n⊢ gcd n n = n"
+
+
+def model₁'' : NativeGenerator := {
+ url := Url.parse! "https://huggingface.co/kaiyuy/ct2-byt5-small"
+ tokenizer := ByT5.tokenizer
+ params := {
+ numReturnSequences := 1
+ }
+}
+
+#eval generate model₁'' "Hello, world!"
+
+
+/-
+```python
+from transformers import AutoTokenizer, T5EncoderModel
+
+tokenizer = AutoTokenizer.from_pretrained("kaiyuy/leandojo-lean4-retriever-byt5-small")
+model = T5EncoderModel.from_pretrained("kaiyuy/leandojo-lean4-retriever-byt5-small")
+
+state = "n : ℕ\n⊢ gcd n n = n"
+tokenized_state = tokenizer(state, return_tensors="pt")
+hidden_state = model(tokenized_state.input_ids).last_hidden_state
+feature = hidden_state.mean(dim=1).squeeze()
+print(feature)
+```
+-/
+
+
+def model₂ : NativeEncoder := {
+ url := Url.parse! "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-retriever-byt5-small"
+ tokenizer := ByT5.tokenizer
+}
+
+#eval encode model₂ "n : ℕ\n⊢ gcd n n = n"
+
+
+/-
+```python
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+tokenizer = AutoTokenizer.from_pretrained("EleutherAI/llemma_7b")
+model = AutoModelForCausalLM.from_pretrained("EleutherAI/llemma_7b")
+
+state = "n : ℕ\n⊢ gcd n n = n"
+tokenized_state = tokenizer(state, return_tensors="pt")
+
+# Generate a single tactic.
+tactic_ids = model.generate(tokenized_state.input_ids, max_length=32)
+tactic = tokenizer.decode(tactic_ids[0], skip_special_tokens=True)
+print(tactic, end="\n\n")
+
+# Generate multiple tactics via beam search.
+tactic_candidates_ids = model.generate(
+ tokenized_state.input_ids,
+ max_length=32,
+ num_beams=2,
+ length_penalty=0.0,
+ do_sample=False,
+ num_return_sequences=2,
+ early_stopping=False,
+)
+tactic_candidates = tokenizer.batch_decode(
+ tactic_candidates_ids, skip_special_tokens=True
+)
+for tac in tactic_candidates:
+ print(tac)
+```
+-/
+
+def model₃ : NativeGenerator := {
+ url := Url.parse! "https://huggingface.co/EleutherAI/llemma_7b"
+ tokenizer := sorry
+ params := {
+ numReturnSequences := 1
+ }
+}
+
+#eval generate model₃ "[GOAL]\nn : ℕ\n⊢ gcd n n = n\n[PROOFSTEP]\n"
+
+
+def model₃' : NativeGenerator := {model₃ with params := {numReturnSequences := 2}}
+
+#eval generate model₃' "[GOAL]\nn : ℕ\n⊢ gcd n n = n\n[PROOFSTEP]\n"
+
+
+structure DummyGenerator where
+ outputs : Array (String × Float)
+
+
+instance : TextToText DummyGenerator where
+ generate model _ _ := return model.outputs
+
+
+def model₄ : DummyGenerator := ⟨#[⟨"Hello, world!", 0.5⟩, ("Hi!", 0.3)]⟩
+
+#eval generate model₄ "n : ℕ\n⊢ gcd n n = n"
+
+
+structure DummyEncoder where
+ output : FloatArray
+
+
+instance : TextToVec DummyEncoder where
+ encode model _ := return model.output
+
+
+def model₅ : DummyEncoder := ⟨FloatArray.mk #[1, 2, 3]⟩
+
+#eval encode model₅ "Hi!"
+
+
+def model₆ : ExternalGenerator := {
+ name := "EleutherAI/llemma_7b"
+ host := "localhost"
+ port := 23333
+}
+
+-- Go to ./python and run `uvicorn server:app --port 23333`
+#eval generate model₆ "[GOAL]\nn : ℕ\n⊢ gcd n n = n\n[PROOFSTEP]\n" "apply"
diff --git a/LeanInfer.lean b/LeanInfer.lean
deleted file mode 100644
index 8619907..0000000
--- a/LeanInfer.lean
+++ /dev/null
@@ -1,5 +0,0 @@
-import LeanInfer.Config
-import LeanInfer.Frontend
-import LeanInfer.Basic
-import LeanInfer.Tactics
-import LeanInfer.LlmAesop
\ No newline at end of file
diff --git a/LeanInfer/Basic.lean b/LeanInfer/Basic.lean
deleted file mode 100644
index f56c553..0000000
--- a/LeanInfer/Basic.lean
+++ /dev/null
@@ -1,142 +0,0 @@
-import Lean
-import LeanInfer.Cache
-import LeanInfer.FFI
-import LeanInfer.Config
-import LeanInfer.Tokenization
-
-open Lean
-
-set_option autoImplicit false
-
-
-namespace LeanInfer
-
-section
-
-
-variable {m : Type → Type} [Monad m] [MonadLog m] [AddMessageContext m]
- [MonadOptions m] [MonadLiftT (ST IO.RealWorld) m] [MonadLiftT IO m] [MonadError m]
-
-
-register_option LeanInfer.verbose : Bool := {
- defValue := false
- descr := "Log various debugging information when running LeanInfer."
-}
-
-
-def isVerbose : m Bool := do
- match LeanInfer.verbose.get? (← getOptions) with
- | some true => return true
- | _ => return false
-
-
-private def isGeneratorInitialized : m Bool := do
- match ← getBackend with
- | .native (.onnx _) => return FFI.isOnnxGeneratorInitialized ()
- | .native (.ct2 _) => return FFI.isCt2GeneratorInitialized ()
- | .ipc .. => unreachable!
-
-
-def initGenerator : IO Bool := do
- let dir ← Cache.getGeneratorDir
- if ¬ (← dir.pathExists) then
- throw $ IO.userError "Cannot find the generator model. Please run `lake script run LeanInfer/download`."
- return false
-
- match ← getBackend with
- | .native (.onnx _) =>
- assert! FFI.initOnnxGenerator dir.toString
- | .native (.ct2 params) =>
- assert! FFI.initCt2Generator dir.toString params.device params.computeType params.deviceIndex params.intraThreads
- | .ipc .. => unreachable!
-
- return true
-
-
-def generate (input : String) (targetPrefix : String) : m (Array (String × Float)) := do
- if ¬ (← isGeneratorInitialized) ∧ ¬ (← initGenerator) then
- return #[]
-
- let config ← getConfig
- let tacticsWithScores := match config.backend with
- | .native (.onnx _) =>
- let numReturnSequences := config.decoding.numReturnSequences
- let maxLength := config.decoding.maxLength
- let temperature := config.decoding.temperature
- let beamSize := config.decoding.beamSize
- let rawOutputs := FFI.onnxGenerate input numReturnSequences maxLength temperature beamSize
- rawOutputs.filter fun (entry : String × Float) => entry.fst ≠ "aesop"
- | .native (.ct2 _) =>
- let inputTokens := tokenizeByt5 input true |>.toArray
- let targetPrefixTokens := tokenizeByt5 targetPrefix false |>.toArray
- let numReturnSequences := config.decoding.numReturnSequences
- let beamSize := config.decoding.beamSize
- let minLength := config.decoding.minLength
- let maxLength := config.decoding.maxLength
- let lengthPenalty := config.decoding.lengthPenalty
- let patience := config.decoding.patience
- let temperature := config.decoding.temperature
- let tokensWithScores := FFI.ct2Generate inputTokens targetPrefixTokens numReturnSequences beamSize minLength maxLength lengthPenalty patience temperature
- tokensWithScores.filterMap fun (ts, s) => match detokenizeByt5 ts with
- | "aesop" => none
- | t => some (t, s)
- | .ipc .. => unreachable!
-
- let rankedTactics := tacticsWithScores.qsort (·.2 > ·.2)
- if ← isVerbose then
- logInfo $ rankedTactics.foldl (init := "Generated tactics with scores:\n")
- fun acc (t, s) => acc ++ s!" {t}: {s}\n"
- return rankedTactics
-
-
-private def isEncoderInitialized : m Bool := do
- match ← getBackend with
- | .native (.onnx _) => return unreachable!
- | .native (.ct2 _) => return FFI.isCt2EncoderInitialized ()
- | .ipc .. => unreachable!
-
-
-def initEncoder : IO Bool := do
- let dir ← Cache.getEncoderDir
- if ¬ (← dir.pathExists) then
- throw $ IO.userError "Cannot find the encoder model. Please run `lake script run LeanInfer/download`."
- return false
-
- match ← getBackend with
- | .native (.onnx _) => unreachable!
- | .native (.ct2 _) => assert! FFI.initCt2Encoder dir.toString
- | .ipc .. => unreachable!
-
- return true
-
-
-def encode (input : String) : m FloatArray := do
- if ¬ (← isEncoderInitialized) ∧ ¬ (← initEncoder) then
- return FloatArray.mk #[]
-
- match ← getBackend with
- | .native (.onnx _) => unreachable!
- | .native (.ct2 _) =>
- let inputTokens := tokenizeByt5 input true |>.toArray
- return FFI.ct2Encode inputTokens
- | .ipc .. => unreachable!
-
-
-def retrieve (input : String) : m (Array (String × Float)) := do
- let query ← encode input
- logInfo s!"{query}"
- return #[("NotImplemented", 0.5)]
-
-end
-
-
-def setConfig (config : Config) : CoreM Unit := do
- assert! config.isValid
- configRef.modify fun _ => config
- if ← isGeneratorInitialized then
- assert! ← initGenerator
- if ← isEncoderInitialized then
- assert! ← initEncoder
-
-
-end LeanInfer
diff --git a/LeanInfer/Cache.lean b/LeanInfer/Cache.lean
deleted file mode 100644
index 67fef74..0000000
--- a/LeanInfer/Cache.lean
+++ /dev/null
@@ -1,51 +0,0 @@
-import Lean
-import LeanInfer.Config
-
-open Lean System
-
-namespace LeanInfer.Cache
-
-
-private def getHomeDir : IO FilePath := do
- let some dir ← IO.getEnv "HOME" | throw $ IO.userError "Cannot find the $HOME environment variable."
- return dir
-
-
-private def ensureDirExists (dir : FilePath) : IO Unit := do
- if !(← dir.pathExists) then
- IO.FS.createDirAll dir
-
-
-def getDefaultCacheDir : IO FilePath := do
- return (← getHomeDir) / ".cache" / "lean_infer"
-
-
-def getCacheDir : IO FilePath := do
- let defaultCacheDir ← getDefaultCacheDir
- let dir := match ← IO.getEnv "LEAN_INFER_CACHE_DIR" with
- | some dir => (dir : FilePath)
- | none => defaultCacheDir
- ensureDirExists dir
- return dir.normalize
-
-
-private def getModelDir (url : HuggingFaceURL) : IO FilePath := do
- let cacheDir ← getCacheDir
- let dir := match url.user with
- | none => cacheDir / url.modelName
- | some user => cacheDir / user / url.modelName
- return dir.normalize
-
-
-/--
-Return the cache directory for storing the current model.
--/
-def getGeneratorDir : IO FilePath := do
- getModelDir (← getGeneratorUrl)
-
-
-def getEncoderDir : IO FilePath := do
- getModelDir (← getEncoderUrl)
-
-
-end LeanInfer.Cache
diff --git a/LeanInfer/Config.lean b/LeanInfer/Config.lean
deleted file mode 100644
index f25501b..0000000
--- a/LeanInfer/Config.lean
+++ /dev/null
@@ -1,131 +0,0 @@
-import Lean
-import LeanInfer.Url
-
-open Lean
-
-set_option autoImplicit false
-
-namespace LeanInfer
-
-structure OnnxParams where
- generatorUrl : HuggingFaceURL
- encoderUrl : HuggingFaceURL
-deriving Repr
-
-def OnnxParams.isValid (params : OnnxParams) : Bool :=
- params.generatorUrl.isValid ∧ params.encoderUrl.isValid
-
--- https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html#translator
-structure CTranslate2Params where
- generatorUrl : HuggingFaceURL
- encoderUrl : HuggingFaceURL
- device : String := "auto"
- deviceIndex : Array UInt64 := #[0]
- computeType : String := "auto"
- -- interThreads : UInt64 := 1
- intraThreads : UInt64 := 0
-deriving Repr
-
-def isValidDevice (device : String) : Bool :=
- #["cpu", "cuda", "auto"].contains device
-
-def isValidComputeType (computeType : String) : Bool :=
- #["default", "auto", "int8", "int8_float32", "int8_float16", "int8_bfloat16", "int16", "float16", "bfloat16", "float32"].contains computeType
-
-def CTranslate2Params.isValid (params : CTranslate2Params) : Bool :=
- params.generatorUrl.isValid ∧ params.encoderUrl.isValid ∧ isValidDevice params.device ∧ isValidComputeType params.computeType
-
-inductive NativeBackend where
- | onnx : OnnxParams → NativeBackend
- | ct2 : CTranslate2Params → NativeBackend
-deriving Repr
-
-def NativeBackend.isValid : NativeBackend → Bool
- | .onnx params => params.isValid
- | .ct2 params => params.isValid
-
-inductive IpcBackend where
- | ct2 : CTranslate2Params → IpcBackend
- | external (host : String) (port : UInt64) : IpcBackend
-deriving Repr
-
-def IpcBackend.isValid : IpcBackend → Bool
- | .ct2 params => params.isValid
- | .external .. => true
-
-inductive Backend where
- | native : NativeBackend → Backend
- | ipc : IpcBackend → Backend
-deriving Repr
-
-def Backend.isValid : Backend → Bool
- | .native b => b.isValid
- | .ipc b => b.isValid
-
-structure DecodingParams where
- numReturnSequences : UInt64
- beamSize : UInt64 := numReturnSequences
- minLength : UInt64 := 1
- maxLength : UInt64 := 1024
- lengthPenalty : Float := 0.0
- patience : Float := 2.0
- temperature : Float := 1.0
-deriving Repr
-
-def DecodingParams.isValid (params : DecodingParams) : Bool :=
- params.numReturnSequences ≥ 1 ∧ params.beamSize ≥ 1 ∧ params.minLength ≥ 0 ∧
- params.maxLength ≥ params.minLength ∧ params.patience ≥ 1.0 ∧ params.temperature ≥ 0.0
-
-structure Config where
- backend : Backend
- decoding : DecodingParams
-deriving Repr
-
-def Config.isValid (config : Config) : Bool :=
- config.backend.isValid ∧ config.decoding.isValid
-
-def safeConfig : Config := {
- backend := .native $ .ct2 {
- generatorUrl := ⟨"kaiyuy", "ct2-leandojo-lean4-tacgen-byt5-small"⟩,
- encoderUrl := ⟨"kaiyuy", "ct2-leandojo-lean4-retriever-byt5-small"⟩,
- },
- decoding := {
- numReturnSequences := 32,
- }
-}
-
-instance : Inhabited Config := ⟨safeConfig⟩
-
-def autoConfig : IO Config := do
- return safeConfig
-
-initialize configRef : IO.Ref Config ← IO.mkRef (← autoConfig)
-
-section
-
-variable {m : Type → Type} [Monad m] [MonadLiftT IO m] [MonadLiftT (ST IO.RealWorld) m]
-
-def getConfig : IO Config := configRef.get
-
-def getBackend : m Backend := do
- return (← getConfig).backend
-
-def getDecodingParams : m DecodingParams := do
- return (← getConfig).decoding
-
-def getGeneratorUrl : m HuggingFaceURL := do
- match ← getBackend with
- | .native (.onnx params) => return params.generatorUrl
- | .native (.ct2 params) => return params.generatorUrl
- | .ipc _ => return unreachable!
-
-
-def getEncoderUrl : m HuggingFaceURL := do
- match ← getBackend with
- | .native (.onnx params) => return params.encoderUrl
- | .native (.ct2 params) => return params.encoderUrl
- | .ipc _ => return unreachable!
-
-end
-
-end LeanInfer
diff --git a/LeanInfer/FFI.lean b/LeanInfer/FFI.lean
deleted file mode 100644
index 15fc2d1..0000000
--- a/LeanInfer/FFI.lean
+++ /dev/null
@@ -1,33 +0,0 @@
-namespace LeanInfer.FFI
-
-@[extern "init_onnx_generator"]
-opaque initOnnxGenerator (modelPath : @& String) : Bool
-
-@[extern "is_onnx_generator_initialized"]
-opaque isOnnxGeneratorInitialized : Unit → Bool
-
-@[extern "onnx_generate"]
-opaque onnxGenerate (input : @& String) (numReturnSequences : UInt64) (maxLength : UInt64)
-(temperature : Float) (beamSize : UInt64) : Array (String × Float)
-
-@[extern "init_ct2_generator"]
-opaque initCt2Generator (modelPath : @& String) (device : @& String) (computeType : @& String) (deviceIndex : @& Array UInt64) (intraThreads : UInt64) : Bool
-
-@[extern "is_ct2_generator_initialized"]
-opaque isCt2GeneratorInitialized : Unit → Bool
-
-@[extern "ct2_generate"]
-opaque ct2Generate (inputTokens : @& Array String) (targetPrefixTokens : @& Array String) (numReturnSequences : UInt64) (beamSize : UInt64)
- (minLength : UInt64) (maxLength : UInt64) (lengthPenalty : Float) (patience : Float) (temperature : Float)
- : Array (Array String × Float)
-
-@[extern "init_ct2_encoder"]
-opaque initCt2Encoder (modelPath : @& String) : Bool
-
-@[extern "is_ct2_encoder_initialized"]
-opaque isCt2EncoderInitialized : Unit → Bool
-
-@[extern "ct2_encode"]
-opaque ct2Encode (inputTokens : @& Array String) : FloatArray
-
-end LeanInfer.FFI
diff --git a/LeanInfer/LlmAesop.lean b/LeanInfer/LlmAesop.lean
deleted file mode 100644
index 9a40614..0000000
--- a/LeanInfer/LlmAesop.lean
+++ /dev/null
@@ -1,17 +0,0 @@
-import LeanInfer.Tactics
-import Aesop
-
-open Lean Lean.Elab.Command
-
-namespace LeanInfer
-
-
-def tacGen : Aesop.TacGen := fun (mvarId : MVarId) => do
- let state ← ppTacticState [mvarId]
- generate state ""
-
-
-macro "#init_llm_aesop" : command => `(#eval (initGenerator : IO Bool) @[aesop 100%] def tacGen := LeanInfer.tacGen #eval getConfig)
-
-
-end LeanInfer
diff --git a/LeanInfer/Tactics.lean b/LeanInfer/Tactics.lean
deleted file mode 100644
index 938ca5c..0000000
--- a/LeanInfer/Tactics.lean
+++ /dev/null
@@ -1,84 +0,0 @@
-import Lean
-import LeanInfer.Basic
-import LeanInfer.Frontend
-import Aesop.Util.Basic
-
-open Lean Elab Tactic
-
-set_option autoImplicit false
-
-namespace LeanInfer
-
-
-register_option LeanInfer.suggest_tactics.check : Bool := {
- defValue := true
- descr := "Check if the generated tactics are valid or if they can prove the goal."
-}
-
-
-def checkTactics : CoreM Bool := do
- match LeanInfer.suggest_tactics.check.get? (← getOptions) with
- | some false => return false
- | _ => return true
-
-
-def ppTacticState : List MVarId → MetaM String
- | [] => return "no goals"
- | [g] => return (← Meta.ppGoal g).pretty
- | goals =>
- return (← goals.foldlM (init := "") (fun a b => do return s!"{a}\n\n{(← Meta.ppGoal b).pretty}")).trim
-
-
-def getPpTacticState : TacticM String := do
- let goals ← getUnsolvedGoals
- ppTacticState goals
-
-
-def suggestTactics (targetPrefix : String) : TacticM (Array (String × Float)) := do
- let state ← getPpTacticState
- if ← isVerbose then
- logInfo s!"State:\n{state}"
- generate state targetPrefix
-
-
-def selectPremises : TacticM (Array (String × Float)) := do
- retrieve (← getPpTacticState)
-
-
-syntax "trace_generate" str : tactic
-syntax "trace_encode" str : tactic
-syntax "pp_state" : tactic
-syntax "suggest_tactics" : tactic
-syntax "suggest_tactics" str : tactic
-syntax "select_premises" : tactic
-
-
-macro_rules
- | `(tactic | suggest_tactics%$tac) => `(tactic | suggest_tactics%$tac "")
-
-
-elab_rules : tactic
- | `(tactic | trace_generate $input:str) => do
- logInfo s!"{← generate input.getString ""}"
-
- | `(tactic | trace_encode $input:str) => do
- logInfo s!"{← encode input.getString}"
-
- | `(tactic | pp_state) => do
- let state ← getPpTacticState
- logInfo state
-
- | `(tactic | suggest_tactics%$tac $pfx:str) => do
- let (tacticsWithScores, elapsed) ← Aesop.time $ suggestTactics pfx.getString
- if ← isVerbose then
- logInfo s!"{elapsed.printAsMillis} for generating {tacticsWithScores.size} tactics"
- let tactics := tacticsWithScores.map (·.1)
- addSuggestions tac pfx tactics.toList (← checkTactics)
-
- | `(tactic | select_premises) => do
- let premisesWithScores ← selectPremises
- let premises := premisesWithScores.map (·.1)
- logInfo s!"{premises}"
-
-
-end LeanInfer
diff --git a/LeanInfer/Url.lean b/LeanInfer/Url.lean
deleted file mode 100644
index 8a03d66..0000000
--- a/LeanInfer/Url.lean
+++ /dev/null
@@ -1,31 +0,0 @@
-import Lean
-
-open System (FilePath)
-
-set_option autoImplicit false
-
-namespace LeanInfer
-
-def HF_BASE_URL := "https://huggingface.co"
-
-structure HuggingFaceURL where
- user : Option String
- modelName : String
-deriving Inhabited
-
-instance : ToString HuggingFaceURL where
- toString url := match url.user with
- | none => s!"{HF_BASE_URL}/{url.modelName}"
- | some user => s!"{HF_BASE_URL}/{user}/{url.modelName}"
-
-instance : Repr HuggingFaceURL where
- reprPrec url x := reprPrec (toString url) x
-
-def HuggingFaceURL.isValid (url : HuggingFaceURL) : Bool :=
- let validModelName := ¬ url.modelName.isEmpty ∧ ¬ url.modelName.contains '/'
- let validUser : Bool := match url.user with
- | none => true
- | some username => ¬ username.isEmpty ∧ ¬ username.contains '/'
- validModelName ∧ validUser
-
-end LeanInfer
diff --git a/LeanInferTests/Aesop.lean b/LeanInferTests/Aesop.lean
deleted file mode 100644
index 7cc044d..0000000
--- a/LeanInferTests/Aesop.lean
+++ /dev/null
@@ -1,8 +0,0 @@
-import LeanInfer
-import Aesop
-
--- The example below wouldn't work without it.
-#init_llm_aesop
-
-example (a b c : Nat) : a + b + c = c + b + a := by
- aesop?
diff --git a/LeanInferTests/Examples.lean b/LeanInferTests/Examples.lean
deleted file mode 100644
index ff8d33a..0000000
--- a/LeanInferTests/Examples.lean
+++ /dev/null
@@ -1,40 +0,0 @@
-import LeanInfer
-
-open LeanInfer
-
-/-
-#eval getConfig
-
-def cfg : Config := {
- backend := .native $ .ct2 {
- generatorUrl := some ⟨"kaiyuy", "ct2-leandojo-lean4-tacgen-byt5-small"⟩,
- encoderUrl := some ⟨"kaiyuy", "ct2-leandojo-lean4-retriever-byt5-small"⟩
- },
- decoding := {numReturnSequences := 64}
-}
-
-#eval setConfig cfg
--/
-
-/-
-example (n : Nat) : Nat.gcd n n = n := by
- select_premises!
- sorry
--/
-
--- set_option LeanInfer.verbose false
--- set_option LeanInfer.suggest_tactics.check true
-
-example (a b c : Nat) : a + b + c = a + c + b := by
- suggest_tactics
- sorry
-
-
-example (a b c : Nat) : a + b + c = a + c + b := by
- suggest_tactics "rw" -- You may provide a prefix to constrain the generated tactics.
- sorry
-
-
-example (a b c : Nat) : a + b + c = a + c + b := by
- select_premises
- sorry
diff --git a/ModelCheckpointManager.lean b/ModelCheckpointManager.lean
new file mode 100644
index 0000000..ad0cda2
--- /dev/null
+++ b/ModelCheckpointManager.lean
@@ -0,0 +1,2 @@
+import ModelCheckpointManager.Url
+import ModelCheckpointManager.Download
diff --git a/ModelCheckpointManager/Download.lean b/ModelCheckpointManager/Download.lean
new file mode 100644
index 0000000..9df8e45
--- /dev/null
+++ b/ModelCheckpointManager/Download.lean
@@ -0,0 +1,105 @@
+import ModelCheckpointManager.Url
+
+set_option autoImplicit false
+
+open System (FilePath)
+
+namespace LeanCopilot
+
+def ensureDirExists (dir : FilePath) : IO Unit := do
+ if ¬ (← dir.pathExists) then
+ IO.FS.createDirAll dir
+
+
+-- TODO: Not sure if this works for Windows.
+def getHomeDir : IO FilePath := do
+ let some dir ← IO.getEnv "HOME" | throw $ IO.userError "Cannot find the $HOME environment variable."
+ return dir
+
+
+def getDefaultCacheDir : IO FilePath := do
+ return (← getHomeDir) / ".cache/lean_copilot/models"
+
+
+def getCacheDir : IO FilePath := do
+ let defaultCacheDir ← getDefaultCacheDir
+ let dir := match ← IO.getEnv "LEAN_COPILOT_CACHE_DIR" with
+ | some dir => (dir : FilePath)
+ | none => defaultCacheDir
+ ensureDirExists dir
+ return dir.normalize
+
+
+inductive ModelPath where
+ | «local» : FilePath → ModelPath
+ | remote : Url → ModelPath
+
+
+def getModelDir (url : Url) : IO FilePath := do
+ return (← getCacheDir) / url.hostname / url.path |>.normalize
+
+
+def isUpToDate (url : Url) : IO Bool := do
+ let dir := ← getModelDir url
+ if ¬ (← dir.pathExists) then
+ return false
+
+ let _ ← IO.Process.run {
+ cmd := "git"
+ args := #["fetch", "--quiet", "--all"]
+ cwd := dir
+ }
+
+ let branch := (← IO.Process.run {
+ cmd := "git"
+ args := #["symbolic-ref", "refs/remotes/origin/HEAD","--short"]
+ cwd := dir
+ }).trim
+
+ let hasRemoteChange := (← IO.Process.run {
+ cmd := "git"
+ args := #["diff", (branch.splitOn "/")[1]!, branch, "--shortstat"]
+ cwd := dir
+ }).trim != ""
+
+ let hasLocalChange := (← IO.Process.run {
+ cmd := "git"
+ args := #["diff", "--shortstat"]
+ cwd := dir
+ }).trim != ""
+
+ return ¬ (hasRemoteChange ∨ hasLocalChange)
+
+
+def initGitLFS : IO Unit := do
+ let proc ← IO.Process.output {
+ cmd := "git"
+ args := #["lfs", "install"]
+ }
+ if proc.exitCode != 0 then
+ throw $ IO.userError "Failed to initialize Git LFS. Please install it."
+
+
+def downloadUnlessUpToDate (url : Url) : IO Unit := do
+ let dir := ← getModelDir url
+ if ← isUpToDate url then
+ println! s!"The model is available at {dir}"
+ return
+
+ println! s!"Downloading the model into {dir}"
+ if ← dir.pathExists then
+ IO.FS.removeDirAll dir
+ let some parentDir := dir.parent | unreachable!
+ IO.FS.createDirAll parentDir
+
+ initGitLFS
+ let proc ← IO.Process.output {
+ cmd := "git"
+ args := #["clone", toString url]
+ cwd := parentDir
+ }
+ if proc.exitCode != 0 then
+ throw $ IO.userError s!"Failed to download the model. You download it manually from {url} and store it in `{dir}/`. See https://huggingface.co/docs/hub/models-downloading for details."
+
+
+end LeanCopilot
diff --git a/ModelCheckpointManager/Main.lean b/ModelCheckpointManager/Main.lean
new file mode 100644
index 0000000..ac29f6d
--- /dev/null
+++ b/ModelCheckpointManager/Main.lean
@@ -0,0 +1,25 @@
+import ModelCheckpointManager.Url
+import ModelCheckpointManager.Download
+
+open LeanCopilot
+
+def builtinModelUrls : List String := [
+ "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-tacgen-byt5-small",
+ "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-retriever-byt5-small",
+ "https://huggingface.co/kaiyuy/premise-embeddings-leandojo-lean4-retriever-byt5-small"
+]
+
+
+def main (args : List String) : IO Unit := do
+ let mut tasks := #[]
+ let urls := Url.parse! <$> (if args.isEmpty then builtinModelUrls else args)
+
+ for url in urls do
+ tasks := tasks.push $ ← IO.asTask $ downloadUnlessUpToDate url
+
+ for t in tasks do
+ match ← IO.wait t with
+ | Except.error e => throw e
+ | Except.ok _ => pure ()
+
+ println! "Done!"
diff --git a/ModelCheckpointManager/Url.lean b/ModelCheckpointManager/Url.lean
new file mode 100644
index 0000000..a6930d0
--- /dev/null
+++ b/ModelCheckpointManager/Url.lean
@@ -0,0 +1,78 @@
+import Lean
+
+open System (FilePath)
+
+set_option autoImplicit false
+
+namespace LeanCopilot
+
+
+structure Url where
+ protocol : String
+ hostname : String
+ path : FilePath
+deriving Inhabited, Repr
+
+
+namespace Url
+
+def isValid (url : Url) : Bool :=
+ ¬ url.protocol.isEmpty ∧ ¬ url.hostname.isEmpty ∧ ¬ url.path.toString.isEmpty ∧ url.path.isRelative ∧ url.path.fileName.isSome
+
+
+def toString (url : Url) : String :=
+ assert! isValid url
+ s!"{url.protocol}://{url.hostname}/{url.path}"
+
+
+instance : ToString Url := ⟨toString⟩
+
+
+def parse (s : String) : Option Url :=
+ let parts := s.splitOn "://"
+ if h : parts.length != 2 then
+ none
+ else
+ have : parts.length > 1 := by
+ by_cases h' : parts.length = 2
+ · rw [h']
+ apply Nat.lt_succ_of_le
+ simp
+ · simp_all
+ have : parts.length > 0 := by
+ apply Nat.lt_of_succ_lt
+ assumption
+ let protocol := parts[0]
+ match parts[1].splitOn "/" with
+ | hostname :: path =>
+ let path := FilePath.mk $ "/".intercalate path
+ let url : Url := ⟨protocol, hostname, path⟩
+ if url.isValid then
+ some url
+ else
+ none
+ | _ => none
+
+
+def parse! (s : String) : Url :=
+ match parse s with
+ | some url => url
+ | none => panic! "Invalid url: {s}"
+
+
+def name! (url : Url) : String :=
+ url.path.fileName.get!
+
+
+private def url₁ := parse! "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-tacgen-byt5-small"
+private def url₂ := parse! "https://huggingface.co/bert-base-uncased"
+
+#eval url₁
+#eval url₂
+
+#eval url₁.name!
+#eval url₂.name!
+
+end Url
+
+end LeanCopilot
diff --git a/README.md b/README.md
index 188eb02..a113777 100644
--- a/README.md
+++ b/README.md
@@ -1,55 +1,61 @@
-LeanInfer: Native Neural Network Inference in Lean 4
+LeanCopilot: Native Neural Network Inference in Lean 4
=============================================
-
+
-LeanInfer provides tactic suggestions by running LLMs through Lean's foreign function interface (FFI). It is in an early stage of development. In the long term, we aim to integrate Lean and machine learning by providing a general and efficient way to run the inference of neural networks in Lean.
+LeanCopilot provides tactic suggestions by running LLMs through Lean's foreign function interface (FFI). It is in an early stage of development. In the long term, we aim to integrate Lean and machine learning by providing a general and efficient way to run the inference of neural networks in Lean.
## Requirements
-* Supported platforms: Linux and macOS (:warning: maybe also Windows WSL, but untested)
+* Supported platforms: Linux (including Windows WSL) and macOS
* Git LFS
* Optional (recommended if you have a [CUDA-enabled GPU](https://developer.nvidia.com/cuda-gpus)): CUDA and [cuDNN](https://developer.nvidia.com/cudnn)
-## Adding LeanInfer as a Dependency to Your Project
+## Adding LeanCopilot as a Dependency to Your Project
:warning: Your package must use a Lean version of at least `lean4:v4.3.0-rc2`.
-1. Add the package configuration option `moreLinkArgs := #["-L./.lake/packages/LeanInfer/.lake/build/lib", "-lonnxruntime", "-lctranslate2"]` to lakefile.lean. Also add LeanInfer as a dependency:
+1. Add the package configuration option `moreLinkArgs := #["-L./.lake/packages/LeanCopilot/.lake/build/lib", "-lonnxruntime", "-lctranslate2"]` to lakefile.lean. Also add LeanCopilot as a dependency:
```lean
require LeanInfer from git "https://github.com/lean-dojo/LeanInfer.git" @ "v0.1.0"
```
-2. Run `lake update LeanInfer`
-3. Run `lake script run LeanInfer/download` to download the models from Hugging Face to `~/.cache/lean_infer/`
+2. Run `lake update LeanCopilot`
+3. Run `lake script run LeanCopilot/download` to download the models from Hugging Face to `~/.cache/lean_copilot/`
4. Run `lake build`
-You may also see the [example here](https://github.com/yangky11/lean4-example/blob/LeanInfer-demo). If you have problems building the project, our [Dockerfile](./Dockerfile), [build.sh](scripts/build.sh) or [build_example.sh](scripts/build_example.sh) may be helpful.
+You may also see the [example here](https://github.com/yangky11/lean4-example/blob/LeanCopilot-demo). If you have problems building the project, our [Dockerfile](./Dockerfile), [build.sh](scripts/build.sh) or [build_example.sh](scripts/build_example.sh) may be helpful.
-## Using LeanInfer
+## Using LeanCopilot
### Generating Tactic Suggestions
-After `import LeanInfer`, you can use the tactic `suggest_tactics` to generate tactic suggestions (see the image above and [this example](LeanInferTests/Examples.lean)). You can click on any of the suggested tactics to use it in the proof.
+After `import LeanCopilot`, you can use the tactic `suggest_tactics` to generate tactic suggestions (see the image above and [this example](LeanCopilotTests/Examples.lean)). You can click on any of the suggested tactics to use it in the proof.
You may provide a prefix to constrain the generated tactics. For example, `suggest_tactics "rw"` would only generate tactics starting with `rw`.
### Searching for Proofs
-You can combine the LLM-generated tactic suggestions with [aesop](https://github.com/leanprover-community/aesop) to search for complete proofs. To do this, simply add `#init_llm_aesop` before using aesop (see [this example](LeanInferTests/Aesop.lean)).
+You can combine the LLM-generated tactic suggestions with [aesop](https://github.com/leanprover-community/aesop) to search for complete proofs. To do this, simply add `#init_llm_aesop` before using aesop (see [this example](LeanCopilotTests/Aesop.lean)).
### Selecting Premises
-Coming soon.
+Coming soon.*
+## Building LeanCopilot
+
+You don't need to build LeanCopilot directly if you use it in a downstream package. Nevertheless, if you really need to build LeanCopilot, it can be done by `lake build`. However, make sure you have installed these dependencies:
+* CMake >= 3.7
+* A C++17 compatible compiler, e.g., recent versions of GCC or Clang
+
## Questions and Bugs
-* For general questions and discussions, please use [GitHub Discussions](https://github.com/lean-dojo/LeanInfer/discussions).
+* For general questions and discussions, please use [GitHub Discussions](https://github.com/lean-dojo/LeanCopilot/discussions).
* To report a potential bug, please open an issue. In the issue, please include your OS information and the exact steps to reproduce the error. The more details you provide, the better we will be able to help you.
@@ -63,7 +69,7 @@ Coming soon.
## Acknowledgements
* [llmstep](https://github.com/wellecks/llmstep) is another tool providing tactic suggestions using LLMs. We use their frontend for displaying tactics but a different mechanism for running the model.
-* We thank Scott Morrison for suggestions on simplifying LeanInfer's installation and Mac Malone for helping implement it. Both Scott and Mac work for the [Lean FRO](https://lean-fro.org/).
+* We thank Scott Morrison for suggestions on simplifying LeanCopilot's installation and Mac Malone for helping implement it. Both Scott and Mac work for the [Lean FRO](https://lean-fro.org/).
* We thank Jannis Limperg for integrating our LLM-generated tactics into aesop (https://github.com/leanprover-community/aesop/pull/70).
diff --git a/cpp/ct2.cpp b/cpp/ct2.cpp
index 0469a0a..f9fe351 100644
--- a/cpp/ct2.cpp
+++ b/cpp/ct2.cpp
@@ -1,126 +1,132 @@
#include
#include
+#include
+#include
#include
#include
#include
+#include
#include
#include
#include
+#include
-#include "utils.h"
-
-ctranslate2::Translator *p_translator = nullptr;
-ctranslate2::Encoder *p_encoder = nullptr;
-
-const std::string EOS_TOKEN = "";
-const std::vector byt5_vocab = {
- "\u0000", "\u0001", "\u0002", "\u0003", "\u0004", "\u0005", "\u0006",
- "\u0007", "\\b", "\t", "\n", "\u000b", "\\f", "\r",
- "\u000e", "\u000f", "\u0010", "\u0011", "\u0012", "\u0013", "\u0014",
- "\u0015", "\u0016", "\u0017", "\u0018", "\u0019", "\u001a", "\u001b",
- "\u001c", "\u001d", "\u001e", "\u001f", " ", "!", "\"",
- "#", "$", "%", "&", "'", "(", ")",
- "*", "+", ",", "-", ".", "/", "0",
- "1", "2", "3", "4", "5", "6", "7",
- "8", "9", ":", ";", "<", "=", ">",
- "?", "@", "A", "B", "C", "D", "E",
- "F", "G", "H", "I", "J", "K", "L",
- "M", "N", "O", "P", "Q", "R", "S",
- "T", "U", "V", "W", "X", "Y", "Z",
- "[", "\\", "]", "^", "_", "`", "a",
- "b", "c", "d", "e", "f", "g", "h",
- "i", "j", "k", "l", "m", "n", "o",
- "p", "q", "r", "s", "t", "u", "v",
- "w", "x", "y", "z", "{", "|", "}",
- "~", "\u007f", "\u0080", "\u0081", "\u0082", "\u0083", "\u0084",
- "\u0085", "\u0086", "\u0087", "\u0088", "\u0089", "\u008a", "\u008b",
- "\u008c", "\u008d", "\u008e", "\u008f", "\u0090", "\u0091", "\u0092",
- "\u0093", "\u0094", "\u0095", "\u0096", "\u0097", "\u0098", "\u0099",
- "\u009a", "\u009b", "\u009c", "\u009d", "\u009e", "\u009f", "\u00a0",
- "\u00a1", "\u00a2", "\u00a3", "\u00a4", "\u00a5", "\u00a6", "\u00a7",
- "\u00a8", "\u00a9", "\u00aa", "\u00ab", "\u00ac", "\u00ad", "\u00ae",
- "\u00af", "\u00b0", "\u00b1", "\u00b2", "\u00b3", "\u00b4", "\u00b5",
- "\u00b6", "\u00b7", "\u00b8", "\u00b9", "\u00ba", "\u00bb", "\u00bc",
- "\u00bd", "\u00be", "\u00bf", "\u00c0", "\u00c1", "\u00c2", "\u00c3",
- "\u00c4", "\u00c5", "\u00c6", "\u00c7", "\u00c8", "\u00c9", "\u00ca",
- "\u00cb", "\u00cc", "\u00cd", "\u00ce", "\u00cf", "\u00d0", "\u00d1",
- "\u00d2", "\u00d3", "\u00d4", "\u00d5", "\u00d6", "\u00d7", "\u00d8",
- "\u00d9", "\u00da", "\u00db", "\u00dc", "\u00dd", "\u00de", "\u00df",
- "\u00e0", "\u00e1", "\u00e2", "\u00e3", "\u00e4", "\u00e5", "\u00e6",
- "\u00e7", "\u00e8", "\u00e9", "\u00ea", "\u00eb", "\u00ec", "\u00ed",
- "\u00ee", "\u00ef", "\u00f0", "\u00f1", "\u00f2", "\u00f3", "\u00f4",
- "\u00f5", "\u00f6", "\u00f7", "\u00f8", "\u00f9", "\u00fa", "\u00fb",
- "\u00fc", "\u00fd", "\u00fe", "\u00ff"};
-
-std::vector byt5_tokenize(const char *input) {
- std::vector tokens;
- int l = strlen(input);
- for (int i = 0; i < l; i++) {
- tokens.push_back(std::string(1, input[i]));
- }
- return tokens;
+#include "json.hpp"
+#include "npy.hpp"
+
+using json = nlohmann::json;
+
+std::map> generators;
+std::map> encoders;
+
+ctranslate2::StorageView *p_premise_embeddings = nullptr;
+json *p_premise_dictionary = nullptr;
+
+inline bool exists(const std::string &path) {
+ std::ifstream f(path.c_str());
+ return f.good();
}
-extern "C" uint8_t init_ct2_generator(
- b_lean_obj_arg _model_path, // String
- b_lean_obj_arg _device, // String
- b_lean_obj_arg _compute_type, // String
- b_lean_obj_arg _device_index, // Array UInt64
- uint64_t intra_threads) { // UInt64
- const char *model_path = lean_string_cstr(_model_path);
- if (!exists(model_path)) {
- return false;
+inline lean_obj_res lean_mk_pair(lean_obj_arg a, lean_obj_arg b) {
+ lean_object *r = lean_alloc_ctor(0, 2, 0);
+ lean_ctor_set(r, 0, a);
+ lean_ctor_set(r, 1, b);
+ return r;
+}
+
+extern "C" uint8_t cuda_available(b_lean_obj_arg) {
+ return ctranslate2::str_to_device("auto") == ctranslate2::Device::CUDA;
+}
+
+template
+bool is_initialized_aux(const std::string &name);
+
+template<>
+bool is_initialized_aux(const std::string &name) {
+ return generators.find(name) != generators.end();
+}
+
+template<>
+bool is_initialized_aux(const std::string &name) {
+ return encoders.find(name) != encoders.end();
+}
+
+extern "C" uint8_t is_generator_initialized(b_lean_obj_arg _name) {
+ std::string name = std::string(lean_string_cstr(_name));
+ return is_initialized_aux(name);
+}
+
+extern "C" uint8_t is_encoder_initialized(b_lean_obj_arg _name) {
+ std::string name = std::string(lean_string_cstr(_name));
+ return is_initialized_aux(name);
+}
+
+template
+bool init_model(b_lean_obj_arg _name, // String
+ b_lean_obj_arg _model_path, // String
+ b_lean_obj_arg _compute_type, // String
+ b_lean_obj_arg _device, // String
+ b_lean_obj_arg _device_index, // Array UInt64
+ std::map> &models) {
+ std::string name = std::string(lean_string_cstr(_name));
+ if (is_initialized_aux(name)) {
+ throw std::runtime_error(name + " already exists.");
}
- if (p_translator != nullptr) {
- delete p_translator;
- p_translator = nullptr;
+ std::string model_path = std::string(lean_string_cstr(_model_path));
+ if (!exists(model_path)) { // Cannot find the model.
+ return false;
}
ctranslate2::Device device =
ctranslate2::str_to_device(lean_string_cstr(_device));
ctranslate2::ComputeType compute_type =
ctranslate2::str_to_compute_type(lean_string_cstr(_compute_type));
+
std::vector device_indices;
const lean_array_object *p_arr = lean_to_array(_device_index);
for (int i = 0; i < p_arr->m_size; i++) {
device_indices.push_back(lean_unbox_uint64(p_arr->m_data[i]));
}
- ctranslate2::ReplicaPoolConfig config;
- config.num_threads_per_replica = intra_threads;
- p_translator = new ctranslate2::Translator(model_path, device, compute_type,
- device_indices, config);
+ auto p_model =
+ std::make_unique(model_path, device, compute_type, device_indices);
+ models.emplace(name, std::move(p_model));
return true;
}
-inline bool is_ct2_generator_initialized_aux() {
- return p_translator != nullptr;
+extern "C" uint8_t init_generator(
+ b_lean_obj_arg _name, // String
+ b_lean_obj_arg _model_path, // String
+ b_lean_obj_arg _compute_type, // String
+ b_lean_obj_arg _device, // String
+ b_lean_obj_arg _device_index) { // Array UInt64
+ return init_model(_name, _model_path, _compute_type, _device, _device_index,
+ generators);
}
-extern "C" uint8_t is_ct2_generator_initialized(lean_object *) {
- return is_ct2_generator_initialized_aux();
+extern "C" uint8_t init_encoder(b_lean_obj_arg _name, // String
+ b_lean_obj_arg _model_path, // String
+ b_lean_obj_arg _compute_type, // String
+ b_lean_obj_arg _device, // String
+ b_lean_obj_arg _device_index) { // Array UInt64
+ return init_model(_name, _model_path, _compute_type, _device, _device_index,
+ encoders);
}
-std::vector convert_tokens(b_lean_obj_arg _tokens) {
+inline std::vector convert_tokens(b_lean_obj_arg _tokens) {
std::vector tokens;
const lean_array_object *p_arr = lean_to_array(_tokens);
-
for (int i = 0; i < p_arr->m_size; i++) {
- std::string t = lean_string_cstr(p_arr->m_data[i]);
- if (t != EOS_TOKEN && std::find(byt5_vocab.begin(), byt5_vocab.end(), t) ==
- std::end(byt5_vocab)) {
- throw std::invalid_argument("Invalid token: " + t);
- }
- tokens.push_back(t);
+ tokens.emplace_back(lean_string_cstr(p_arr->m_data[i]));
}
-
return tokens;
}
-extern "C" lean_obj_res ct2_generate(
- b_lean_obj_arg _input_tokens, // Array String
+extern "C" lean_obj_res generate(
+ b_lean_obj_arg _name, // String
+ b_lean_obj_arg _input_tokens, // Array String
b_lean_obj_arg _target_prefix_tokens, // Array String
uint64_t num_return_sequences, // UInt64
uint64_t beam_size, // UInt64
@@ -130,8 +136,9 @@ extern "C" lean_obj_res ct2_generate(
double patience, // Float
double temperature) { // Float
// Check the arguments.
- if (!is_ct2_generator_initialized_aux()) {
- throw std::runtime_error("CT2 generator is not initialized.");
+ std::string name = std::string(lean_string_cstr(_name));
+ if (!is_initialized_aux(name)) {
+ throw std::runtime_error(name + " hasn't been initialized.");
}
if (num_return_sequences <= 0) {
throw std::invalid_argument("num_return_sequences must be positive.");
@@ -167,63 +174,45 @@ extern "C" lean_obj_res ct2_generate(
// Get the input tokens ready.
std::vector input_tokens = convert_tokens(_input_tokens);
- assert(input_tokens.back() == EOS_TOKEN);
std::vector target_prefix_tokens =
convert_tokens(_target_prefix_tokens);
// Generate tactics with beam search.
- ctranslate2::TranslationResult results = p_translator->translate_batch(
+ ctranslate2::TranslationResult results = generators.at(name)->translate_batch(
{input_tokens}, {target_prefix_tokens}, opts)[0];
assert(results.hypotheses.size() == num_return_sequences &&
results.scores.size() == num_return_sequences);
// Return the output.
- lean_array_object *output = reinterpret_cast(
- lean_alloc_array(num_return_sequences, num_return_sequences));
+ lean_object *output = lean_mk_empty_array();
for (int i = 0; i < num_return_sequences; i++) {
int l = results.hypotheses[i].size();
- lean_array_object *tokens =
- reinterpret_cast(lean_alloc_array(l, l));
+
+ lean_object *tokens = lean_mk_empty_array();
for (int j = 0; j < l; j++) {
- tokens->m_data[j] = lean_mk_string(results.hypotheses[i][j].c_str());
+ tokens = lean_array_push(
+ tokens, lean_mk_string(results.hypotheses[i][j].c_str()));
}
double score = std::exp(results.scores[i]);
assert(0.0 <= score && score <= 1.0);
- output->m_data[i] = lean_mk_pair(reinterpret_cast(tokens),
- lean_box_float(score));
+ output =
+ lean_array_push(output, lean_mk_pair(tokens, lean_box_float(score)));
}
- return reinterpret_cast(output);
+ return output;
}
-extern "C" uint8_t init_ct2_encoder(b_lean_obj_arg model_path) {
- const char *dir = lean_string_cstr(model_path);
- if (!exists(dir)) {
- return false;
- }
- if (p_encoder != nullptr) {
- delete p_encoder;
+extern "C" lean_obj_res encode(b_lean_obj_arg _name, // String
+ b_lean_obj_arg _input_tokens) { // Array String
+ std::string name = std::string(lean_string_cstr(_name));
+ if (!is_initialized_aux(name)) {
+ throw std::runtime_error(name + " hasn't been initialized.");
}
- p_encoder = new ctranslate2::Encoder(dir, ctranslate2::Device::CPU);
- return true;
-}
-
-inline bool is_ct2_encoder_initialized_aux() { return p_encoder != nullptr; }
-
-extern "C" uint8_t is_ct2_encoder_initialized(lean_object *) {
- return is_ct2_encoder_initialized_aux();
-}
-extern "C" lean_obj_res ct2_encode(b_lean_obj_arg _input_tokens) {
std::vector input_tokens = convert_tokens(_input_tokens);
- // std::vector input_tokens = {"n", " ", ":", " ", "\u00e2",
- // "\u0084", "\u0095", "\n", "\u00e2", "\u008a", "\u00a2", " ", "g", "c", "d",
- // " ", "n", " ", "n", " ", "=", " ", "n", EOS_TOKEN};
- assert(input_tokens.back() == EOS_TOKEN);
-
ctranslate2::EncoderForwardOutput results =
- p_encoder->forward_batch_async({input_tokens}).get();
+ encoders.at(name)->forward_batch_async({input_tokens}).get();
ctranslate2::StorageView hidden_state = results.last_hidden_state;
assert(hidden_state.dim(0) == 1);
@@ -241,3 +230,126 @@ extern "C" lean_obj_res ct2_encode(b_lean_obj_arg _input_tokens) {
return arr;
}
+
+extern "C" uint8_t init_premise_embeddings(
+ b_lean_obj_arg _path, // String
+ b_lean_obj_arg _device) { // String
+ std::string path = std::string(lean_string_cstr(_path));
+ if (!exists(path)) {
+ return false;
+ }
+ if (p_premise_embeddings != nullptr) {
+ delete p_premise_embeddings;
+ }
+
+ // ctranslate2::Device device = ctranslate2::str_to_device(lean_string_cstr(_device));
+ // TODO: We should remove this line when everything can work well on CUDA.
+ ctranslate2::Device device = ctranslate2::Device::CPU;
+
+ const auto &d = npy::read_npy(path);
+ std::vector data = d.data;
+ std::vector shape = d.shape;
+ bool fortran_order = d.fortran_order;
+
+ std::vector data_f;
+ data_f.resize(data.size());
+ std::transform(data.begin(), data.end(), data_f.begin(),
+ [](double d) { return static_cast(d); });
+
+ std::vector shape_i64;
+ shape_i64.resize(shape.size());
+ std::transform(shape.begin(), shape.end(), shape_i64.begin(),
+ [](unsigned long ul) { return static_cast(ul); });
+
+ p_premise_embeddings =
+ new ctranslate2::StorageView(shape_i64, data_f, device);
+ return true;
+}
+
+inline bool premise_embeddings_initialized_aux() {
+ return p_premise_embeddings != nullptr;
+}
+
+extern "C" uint8_t premise_embeddings_initialized(lean_object *) {
+ return premise_embeddings_initialized_aux();
+}
+
+extern "C" uint8_t init_premise_dictionary(b_lean_obj_arg _path) {
+ std::string path = std::string(lean_string_cstr(_path));
+ if (!exists(path)) {
+ return false;
+ }
+ if (p_premise_dictionary != nullptr) {
+ delete p_premise_dictionary;
+ }
+
+ std::ifstream f(path);
+ p_premise_dictionary = new json(json::parse(f));
+
+ return true;
+}
+
+inline bool premise_dictionary_initialized_aux() {
+ return p_premise_dictionary != nullptr;
+}
+
+extern "C" uint8_t premise_dictionary_initialized(lean_object *) {
+ return premise_dictionary_initialized_aux();
+}
+
+extern "C" lean_obj_res retrieve(b_lean_obj_arg _query_emb,
+ uint64_t _k) { // FloatArray
+ // lean_object *arr
+ // assert(p_premise_embeddings && static_cast(p_arr->m_size) ==
+ // p_premise_embeddings->dim(1));
+
+ int64_t d = lean_unbox(lean_float_array_size(_query_emb));
+ std::vector query_emb_data;
+ for (int i = 0; i < d; i++) {
+ query_emb_data.push_back(lean_float_array_uget(_query_emb, i));
+ }
+
+ ctranslate2::Device device = p_premise_embeddings->device();
+ ctranslate2::StorageView query_emb =
+ ctranslate2::StorageView({d, 1}, query_emb_data, device);
+
+ // TODO:
+ ctranslate2::ops::MatMul matmul(false, false, 1.0);
+ long int k = static_cast(_k);
+ ctranslate2::ops::TopK topk(k, -1);
+
+ int num_premises = p_premise_embeddings->dim(0);
+ std::vector probs_shape{num_premises, 1};
+
+ ctranslate2::StorageView probs = ctranslate2::StorageView(
+ probs_shape, ctranslate2::DataType::FLOAT32, device);
+ matmul(*p_premise_embeddings, query_emb, probs);
+ probs.resize({num_premises});
+
+ ctranslate2::StorageView topk_values =
+ ctranslate2::StorageView({k}, ctranslate2::DataType::FLOAT32, device);
+ ctranslate2::StorageView topk_indices =
+ ctranslate2::StorageView({k}, ctranslate2::DataType::INT32, device);
+ topk(probs, topk_values, topk_indices);
+
+ lean_object *output = lean_mk_empty_array();
+ const int *p_topk_indices = topk_indices.data();
+ const float *p_topk_values = topk_values.data();
+
+ for (int i = 0; i < k; i++) {
+ int idx = p_topk_indices[i];
+ assert(0 < idx && idx < num_premises);
+ // [NOTE]: This is where the server crash occurs on CUDA.
+ const std::string this_premise = (*p_premise_dictionary)[std::to_string(idx)]["full_name"];
+ const std::string this_path = (*p_premise_dictionary)[std::to_string(idx)]["path"];
+ const std::string this_code = (*p_premise_dictionary)[std::to_string(idx)]["code"];
+
+ output = lean_array_push(output, lean_mk_pair(
+ lean_mk_string(this_premise.c_str()),
+ lean_mk_pair(lean_mk_string(this_path.c_str()),
+ lean_mk_pair(lean_mk_string(this_code.c_str()),
+ lean_box_float(p_topk_values[i])))));
+ }
+
+ return output;
+}
diff --git a/cpp/json.hpp b/cpp/json.hpp
new file mode 100644
index 0000000..09db5f1
--- /dev/null
+++ b/cpp/json.hpp
@@ -0,0 +1,24757 @@
+// __ _____ _____ _____
+// __| | __| | | | JSON for Modern C++
+// | | |__ | | | | | | version 3.11.2
+// |_____|_____|_____|_|___| https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann
+// SPDX-License-Identifier: MIT
+
+/****************************************************************************\
+ * Note on documentation: The source files contain links to the online *
+ * documentation of the public API at https://json.nlohmann.me. This URL *
+ * contains the most recent documentation and should also be applicable to *
+ * previous versions; documentation for deprecated functions is not *
+ * removed, but marked deprecated. See "Generate documentation" section in *
+ * file docs/README.md. *
+\****************************************************************************/
+
+#ifndef INCLUDE_NLOHMANN_JSON_HPP_
+#define INCLUDE_NLOHMANN_JSON_HPP_
+
+#include // all_of, find, for_each
+#include // nullptr_t, ptrdiff_t, size_t
+#include // hash, less
+#include // initializer_list
+#ifndef JSON_NO_IO
+ #include // istream, ostream
+#endif // JSON_NO_IO
+#include // random_access_iterator_tag
+#include // unique_ptr
+#include // string, stoi, to_string
+#include // declval, forward, move, pair, swap
+#include // vector
+
+// #include
+// __ _____ _____ _____
+// __| | __| | | | JSON for Modern C++
+// | | |__ | | | | | | version 3.11.2
+// |_____|_____|_____|_|___| https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann
+// SPDX-License-Identifier: MIT
+
+
+
+#include
+
+// #include
+// __ _____ _____ _____
+// __| | __| | | | JSON for Modern C++
+// | | |__ | | | | | | version 3.11.2
+// |_____|_____|_____|_|___| https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann
+// SPDX-License-Identifier: MIT
+
+
+
+// This file contains all macro definitions affecting or depending on the ABI
+
+#ifndef JSON_SKIP_LIBRARY_VERSION_CHECK
+ #if defined(NLOHMANN_JSON_VERSION_MAJOR) && defined(NLOHMANN_JSON_VERSION_MINOR) && defined(NLOHMANN_JSON_VERSION_PATCH)
+ #if NLOHMANN_JSON_VERSION_MAJOR != 3 || NLOHMANN_JSON_VERSION_MINOR != 11 || NLOHMANN_JSON_VERSION_PATCH != 2
+ #warning "Already included a different version of the library!"
+ #endif
+ #endif
+#endif
+
+#define NLOHMANN_JSON_VERSION_MAJOR 3 // NOLINT(modernize-macro-to-enum)
+#define NLOHMANN_JSON_VERSION_MINOR 11 // NOLINT(modernize-macro-to-enum)
+#define NLOHMANN_JSON_VERSION_PATCH 2 // NOLINT(modernize-macro-to-enum)
+
+#ifndef JSON_DIAGNOSTICS
+ #define JSON_DIAGNOSTICS 0
+#endif
+
+#ifndef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON
+ #define JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 0
+#endif
+
+#if JSON_DIAGNOSTICS
+ #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS _diag
+#else
+ #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS
+#endif
+
+#if JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON
+ #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON _ldvcmp
+#else
+ #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON
+#endif
+
+#ifndef NLOHMANN_JSON_NAMESPACE_NO_VERSION
+ #define NLOHMANN_JSON_NAMESPACE_NO_VERSION 0
+#endif
+
+// Construct the namespace ABI tags component
+#define NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b) json_abi ## a ## b
+#define NLOHMANN_JSON_ABI_TAGS_CONCAT(a, b) \
+ NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b)
+
+#define NLOHMANN_JSON_ABI_TAGS \
+ NLOHMANN_JSON_ABI_TAGS_CONCAT( \
+ NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS, \
+ NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON)
+
+// Construct the namespace version component
+#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) \
+ _v ## major ## _ ## minor ## _ ## patch
+#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(major, minor, patch) \
+ NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch)
+
+#if NLOHMANN_JSON_NAMESPACE_NO_VERSION
+#define NLOHMANN_JSON_NAMESPACE_VERSION
+#else
+#define NLOHMANN_JSON_NAMESPACE_VERSION \
+ NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(NLOHMANN_JSON_VERSION_MAJOR, \
+ NLOHMANN_JSON_VERSION_MINOR, \
+ NLOHMANN_JSON_VERSION_PATCH)
+#endif
+
+// Combine namespace components
+#define NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) a ## b
+#define NLOHMANN_JSON_NAMESPACE_CONCAT(a, b) \
+ NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b)
+
+#ifndef NLOHMANN_JSON_NAMESPACE
+#define NLOHMANN_JSON_NAMESPACE \
+ nlohmann::NLOHMANN_JSON_NAMESPACE_CONCAT( \
+ NLOHMANN_JSON_ABI_TAGS, \
+ NLOHMANN_JSON_NAMESPACE_VERSION)
+#endif
+
+#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN
+#define NLOHMANN_JSON_NAMESPACE_BEGIN \
+ namespace nlohmann \
+ { \
+ inline namespace NLOHMANN_JSON_NAMESPACE_CONCAT( \
+ NLOHMANN_JSON_ABI_TAGS, \
+ NLOHMANN_JSON_NAMESPACE_VERSION) \
+ {
+#endif
+
+#ifndef NLOHMANN_JSON_NAMESPACE_END
+#define NLOHMANN_JSON_NAMESPACE_END \
+ } /* namespace (inline namespace) NOLINT(readability/namespace) */ \
+ } // namespace nlohmann
+#endif
+
+// #include
+// __ _____ _____ _____
+// __| | __| | | | JSON for Modern C++
+// | | |__ | | | | | | version 3.11.2
+// |_____|_____|_____|_|___| https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann
+// SPDX-License-Identifier: MIT
+
+
+
+#include // transform
+#include // array
+#include // forward_list
+#include // inserter, front_inserter, end
+#include