Skip to content

Commit

Permalink
Merge pull request #35 from lean-dojo/dev
Browse files Browse the repository at this point in the history
Refactor
  • Loading branch information
Kaiyu Yang authored Dec 6, 2023
2 parents a5cc231 + affc4a5 commit 8166077
Show file tree
Hide file tree
Showing 47 changed files with 27,030 additions and 1,928 deletions.
1 change: 0 additions & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
*.txt
.vscode
build
onnx*
tmp
lake-packages
.lake
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
/.lake

.vscode
*.olean
*.olean
6 changes: 3 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
FROM ubuntu:latest

WORKDIR /LeanInfer
WORKDIR /LeanCopilot
COPY . .

# Install dependencies.
Expand All @@ -14,5 +14,5 @@ RUN curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -

# Build the Lean project.
RUN lake build
RUN lake script run LeanInfer/download
RUN lake build LeanInferTests
RUN lake script run LeanCopilot/download
RUN lake build LeanCopilotTests
5 changes: 5 additions & 0 deletions LeanCopilot.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import LeanCopilot.Models
import LeanCopilot.Frontend
import LeanCopilot.Options
import LeanCopilot.Tactics
import LeanCopilot.LlmAesop
File renamed without changes.
23 changes: 23 additions & 0 deletions LeanCopilot/LlmAesop.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import LeanCopilot.Tactics
import LeanCopilot.Options
import Aesop

open Lean Lean.Elab.Command

namespace LeanCopilot


def tacGen : Aesop.TacGen := fun (mvarId : MVarId) => do
let state ← ppTacticState [mvarId]
let nm ← SuggestTactics.getGeneratorName
let model ← getGenerator nm
generate model state ""


macro "#init_llm_aesop" : command => `(@[aesop 100%] def tacGen := LeanCopilot.tacGen)


macro "search_proof" : tactic => `(tactic| aesop?)


end LeanCopilot
4 changes: 4 additions & 0 deletions LeanCopilot/Models.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import LeanCopilot.Models.Interface
import LeanCopilot.Models.Defs
import LeanCopilot.Models.Registry
import LeanCopilot.Models.FFI
27 changes: 27 additions & 0 deletions LeanCopilot/Models/Builtin.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import ModelCheckpointManager
import LeanCopilot.Models.ByT5

set_option autoImplicit false

namespace LeanCopilot.Builtin


def generator : NativeGenerator := {
url := Url.parse! "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-tacgen-byt5-small"
tokenizer := ByT5.tokenizer
params := {
numReturnSequences := 32
}
}


def encoder : NativeEncoder := {
url := Url.parse! "https://huggingface.co/kaiyuy/ct2-leandojo-lean4-retriever-byt5-small"
tokenizer := ByT5.tokenizer
}


def premisesUrl := Url.parse! "https://huggingface.co/kaiyuy/premise-embeddings-leandojo-lean4-retriever-byt5-small"


end LeanCopilot.Builtin
33 changes: 20 additions & 13 deletions LeanInfer/Tokenization.lean → LeanCopilot/Models/ByT5.lean
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
namespace LeanInfer
import LeanCopilot.Models.Defs

#eval "\t".length
namespace LeanCopilot.ByT5

def byt5Tokens : Array String := #[

def vocab : Array String := #[
"\u0000",
"\u0001",
"\u0002",
Expand Down Expand Up @@ -263,23 +264,29 @@ def byt5Tokens : Array String := #[


def byteToToken (b : UInt8) : String :=
byt5Tokens.get! b.toNat
vocab.get! b.toNat


def tokenToByte! (t : String) : UInt8 :=
byt5Tokens.findIdx? (· = t) |>.get! |>.toUInt8
vocab.findIdx? (· = t) |>.get! |>.toUInt8


def tokenizeByt5 (text : String) (addEOS : Bool) : List String :=
let tokens := byteToToken <$> text.toUTF8.toList
if addEOS then
tokens ++ ["</s>"]
else
tokens
def tokenize (text : String) : Array String :=
(byteToToken <$> text.toUTF8.toList).toArray


def detokenizeByt5 (tokens : Array String) : String :=
def detokenize (tokens : Array String) : String :=
String.fromUTF8Unchecked ⟨tokens.map tokenToByte!⟩


end LeanInfer
def eosToken := "</s>"


def tokenizer : Tokenizer := {
tokenize := tokenize,
detokenize := detokenize,
eosToken := eosToken
}


end LeanCopilot.ByT5
170 changes: 170 additions & 0 deletions LeanCopilot/Models/Defs.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import Lean
import ModelCheckpointManager
import LeanCopilot.Models.Interface

set_option autoImplicit false

open Lean
open System (FilePath)

namespace LeanCopilot


inductive Device where
| cpu
| cuda
| auto
deriving Repr


instance : Inhabited Device where
default := .auto


def Device.toString : Device → String
| Device.cpu => "cpu"
| Device.cuda => "cuda"
| Device.auto => "auto"

instance : ToString Device := ⟨Device.toString⟩


inductive ComputeType where
| default
| auto
| int8
| int8_float32
| int8_float16
| int8_bfloat16
| int16
| float16
| bfloat16
| float32
deriving Repr


def ComputeType.toString : ComputeType → String
| ComputeType.default => "default"
| ComputeType.auto => "auto"
| ComputeType.int8 => "int8"
| ComputeType.int8_float32 => "int8_float32"
| ComputeType.int8_float16 => "int8_float16"
| ComputeType.int8_bfloat16 => "int8_bfloat16"
| ComputeType.int16 => "int16"
| ComputeType.float16 => "float16"
| ComputeType.bfloat16 => "bfloat16"
| ComputeType.float32 => "float32"


instance : ToString ComputeType := ⟨ComputeType.toString⟩


structure Tokenizer where
tokenize : String → Array String
detokenize : Array String → String
eosToken : String


structure NativeModel where
url : Url
device : Device := .auto
deviceIndex : Array UInt64 := #[0]
computeType : ComputeType := .default
tokenizer : Tokenizer


def NativeModel.name (model : NativeModel) : String := model.url.name!


def NativeModel.path (model : NativeModel) : IO FilePath :=
getModelDir model.url


structure BeamSearchParams where
numReturnSequences : UInt64
beamSize : UInt64 := numReturnSequences
minLength : UInt64 := 1
maxLength : UInt64 := 1024
lengthPenalty : Float := 0.0
patience : Float := 2.0
temperature : Float := 1.0
deriving Repr


structure NativeGenerator extends NativeModel where
params : BeamSearchParams


structure NativeEncoder extends NativeModel


structure ExternalModel where
name : String
host : String := "localhost"
port : UInt16 := 23333
deriving Inhabited, Repr


structure ExternalGenerator extends ExternalModel
deriving Repr


structure ExternalRequest where
name : String
input : String
«prefix» : String
deriving ToJson


structure ExternalResponse where
outputs : Array (String × Float)
deriving FromJson


def ExternalGenerator.generate (model : ExternalGenerator) (input : String) (targetPrefix : String) : IO $ Array (String × Float) := do
let url := s!"http://{model.host}:{model.port}/generate"
let req : ExternalRequest := {
name := model.name,
input := input,
«prefix» := targetPrefix
}
let reqStr := (toJson req).pretty 99999999999999999
let out ← IO.Process.run {
cmd := "curl"
args := #["-X", "POST", url, "-H", "accept: application/json", "-H", "Content-Type: application/json", "-d", reqStr]
}

let some json := Json.parse out |>.toOption | throw $ IO.userError "Failed to parse response"
let some res := (fromJson? json : Except String ExternalResponse) |>.toOption | throw $ IO.userError "Failed to parse response"
return res.outputs


instance : TextToText ExternalGenerator := ⟨ExternalGenerator.generate⟩


structure ExternalEncoder extends ExternalModel
deriving Repr


def ExternalEncoder.encode (model : ExternalEncoder) (input : String) : IO FloatArray := do
return FloatArray.mk #[0.0]


instance : TextToVec ExternalEncoder := ⟨ExternalEncoder.encode⟩


structure GenericGenerator where
generate : String → String → IO (Array (String × Float))


instance : TextToText GenericGenerator := ⟨GenericGenerator.generate⟩


structure GenericEncoder where
encode : String → IO FloatArray


instance : TextToVec GenericEncoder := ⟨GenericEncoder.encode⟩


end LeanCopilot
Loading

0 comments on commit 8166077

Please sign in to comment.