diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 4484b293..15ed2764 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -223,8 +223,8 @@ jobs: -batch_size 10 \ -beam_size 1 \ -seed 1 \ - -random_sampling_topk "-1" \ - -random_sampling_temp 0.0001 \ + -top_k "-1" \ + -temperature 0.0001 \ -tgt eole/tests/data/morph/tgt.valid \ -out /tmp/trans diff eole/tests/data/morph/tgt.valid /tmp/trans && rm /tmp/trans @@ -253,8 +253,8 @@ jobs: -verbose -batch_size 1 \ -beam_size 1 \ -seed 1 \ - -random_sampling_topk -1 \ - -random_sampling_temp 0.0001 \ + -top_k -1 \ + -temperature 0.0001 \ -ban_unk_token \ -length_penalty none \ -out /tmp/gen @@ -266,9 +266,9 @@ jobs: -verbose -batch_size 1 \ -beam_size 1 \ -seed 3 \ - -random_sampling_topk -1 \ - -random_sampling_topp 0.95 \ - -random_sampling_temp 1 \ + -top_k -1 \ + -top_p 0.95 \ + -temperature 1 \ -ban_unk_token \ -length_penalty none \ -out /tmp/gen @@ -280,9 +280,9 @@ jobs: -verbose -batch_size 1 \ -beam_size 10 \ -seed 2 \ - -random_sampling_topk 50 \ - -random_sampling_topp 0.95 \ - -random_sampling_temp 1 \ + -top_k 50 \ + -top_p 0.95 \ + -temperature 1 \ -length_penalty avg \ -ban_unk_token \ -min_length 5 \ diff --git a/eole/bin/convert/convert_HF.py b/eole/bin/convert/convert_HF.py index 6fe01df7..92461278 100755 --- a/eole/bin/convert/convert_HF.py +++ b/eole/bin/convert/convert_HF.py @@ -788,6 +788,8 @@ def get_weight(checkpoint, tensor_name): if ( tokenizer_model is not None ): # sentencepiece mode (might be good to check it's a SP model) + src_subword_type = "sentencepiece" + tokenizer_basename = os.path.basename(tokenizer_model) tokenizer = Tokenizer(model_path=tokenizer_model) vocab = tokenizer.vocab # vocab[3] = DefaultTokens.PAD @@ -806,13 +808,20 @@ def get_weight(checkpoint, tensor_name): special_tokens=["", "", ""], ) else: # # BPE mode - we leverage the HF tokenizer.json info + src_subword_type = "bpe" with open(tokenizer_json, encoding="utf-8") as f: data = json.load(f) - vocab = [ - tok if tok != "Ā" else DefaultTokens.PAD - # "Ā" is '\x00' in unicode (cf tokenize.py gpt2 mapping) - for tok in data["model"]["vocab"] - ] + # gpt2_pretok + gpt2_pretok = False + pretokenizers = data.get("pre_tokenizer", {}).get("pretokenizers", [{}]) + for pretokenizer in pretokenizers: + if pretokenizer.get("type", None) == "ByteLevel": + gpt2_pretok = True + vocab = [ + tok if tok != "Ā" else DefaultTokens.PAD + # "Ā" is '\x00' in unicode (cf tokenize.py gpt2 mapping) + for tok in data["model"]["vocab"] + ] voc_size = len(vocab) if vocab_size > voc_size: for i in range(vocab_size - voc_size): @@ -834,8 +843,10 @@ def get_weight(checkpoint, tensor_name): src_vocab = pyonmttok.build_vocab_from_tokens(vocab) + tokenizer_basename = "bpe.model" + with open( - os.path.join(directory_path, "bpe.model"), "w", encoding="utf-8" + os.path.join(directory_path, tokenizer_basename), "w", encoding="utf-8" ) as bpemodel: bpemodel.write("v3;false;false;false;Ġ;Ġ\n") for merge in data["model"]["merges"]: @@ -928,3 +939,21 @@ def get_weight(checkpoint, tensor_name): os.path.join(directory_path, "config.json"), "w", encoding="utf-8" ) as f: json.dump(config_dict, f, indent=2, ensure_ascii=False) + + inference_dict = { + "transforms": ["onmt_tokenize"], + "transforms_configs": { + "onmt_tokenize": { + "src_subword_type": src_subword_type, + "src_subword_model": os.path.join( + "${MODEL_PATH}", tokenizer_basename + ), + "gpt2_pretok": gpt2_pretok, + } + }, + } + + with open( + os.path.join(directory_path, 
"inference.json"), "w", encoding="utf-8" + ) as f: + json.dump(inference_dict, f, indent=2, ensure_ascii=False) diff --git a/eole/bin/run/serve.py b/eole/bin/run/serve.py new file mode 100644 index 00000000..3db91589 --- /dev/null +++ b/eole/bin/run/serve.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python + +import os +import time +import gc +import json +import yaml + +from typing import List, Union + +import torch +import uvicorn + +from fastapi import FastAPI, Request +from fastapi.responses import HTMLResponse +from pydantic import BaseModel, Field + +from eole.inference_engine import InferenceEnginePY +from eole.config.run import PredictConfig +from eole.config.inference import DecodingConfig +from eole.bin import register_bin, BaseBin +from eole.utils.logging import logger + +STATUS_OK = "ok" +STATUS_ERROR = "error" + + +class TextRequest(DecodingConfig): + """ + Standard text "completion" request + (as well as encoder/decoder models e.g. translation). + """ + + model: int | str = Field(description="Model identifier from server configuration.") + inputs: Union[str, List[str]] = Field( + description="List of inputs to run inference on. " + "A single string will be automatically cast to a single item list." + ) + + +class TextResponse(BaseModel): + """ + Response of TextRequest. + """ + + predictions: List[List[str]] = Field( + description="List of prediction(s) for each input(s)." + ) + scores: List[List[float]] = Field( + description="Pred scores from the model for each prediction." + ) + + +# class ChatRequest(DecodingConfig): +# model: str +# messages: List[dict] + + +# class ChatResponse(BaseModel): +# choices: List[dict] + + +class Server(object): + """ + Main server class to manage configuration, models and corresponding constraints. + """ + + def __init__(self): + self.start_time = time.time() + self.models = {} + self.models_root = None + + def start(self, server_config_path): + with open(server_config_path) as f: + server_config = yaml.safe_load(f) + self.models_root = server_config["models_root"] + for model in server_config["models"]: + # instantiate models + # add some safeguards here, download from HF, etc. + model_id = model["id"] + model_path = model["path"] + self.models[model_id] = Model( + model_path=model_path, + models_root=self.models_root, + model_type=model.get("model_type", "default"), + ) + if model.get("preload", False): + self.models[model_id].start_engine() + + def available_models(self): + models = [] + for model_id, model in self.models.items(): + models.append({"id": model_id}) + return models + + +class Model(object): + def __init__( + self, model_path=None, preload=False, models_root=None, model_type=False + ): + self.loaded = False + self.engine = None + self.preload = preload + self.models_root = models_root + self.model_path = model_path + self.local_path = None + self.model_type = model_type + + def get_config(self): + # look for inference config supposedly in model_dir/inference.json + config_path = os.path.join(self.local_path, "inference.json") + if os.path.exists(config_path): + with open(config_path) as f: + os.environ["MODEL_PATH"] = self.local_path + config_dict = json.loads(os.path.expandvars(f.read())) + + self.config = PredictConfig( + src="dummy", + model_path=self.local_path, + # TODO improve this + gpu_ranks=[0], + world_size=1, + precision="fp16", + **config_dict, + ) + + def override_opts(self): + """ + Potentially override some opts from a config file? 
+ """ + pass + + def maybe_retrieve_model(self): + from huggingface_hub import HfApi, snapshot_download + + hf_api = HfApi() + try: + hf_api.model_info(self.model_path) + except Exception: + self.local_path = self.model_path + else: + self.local_path = os.path.join(self.models_root, self.model_path) + logger.info( + f"Downloading {self.model_path} from huggingface, " + f"to local directory {self.local_path}" + ) + snapshot_download(repo_id=self.model_path, local_dir=self.local_path) + + def start_engine(self): + """ + We might want to call this "load"... + """ + + self.maybe_retrieve_model() + self.get_config() + self.engine = InferenceEnginePY(self.config) + self.loaded = True + + def unload(self): + """ + stop_engine if start_engine is not renamed... + Not super clean, we might want to do better some day... + """ + del self.engine + gc.collect() + torch.cuda.empty_cache() + self.engine = None + self.loaded = False + + def infer(self, inputs, settings={}): + if type(inputs) == str: + inputs = [inputs] + if not (self.loaded): + self.start_engine() + scores, _, preds = self.engine.infer_list(inputs, settings=settings) + return scores, preds + + +def create_app(config_file): + app = FastAPI() + + server = Server() + server.start(config_file) + + @app.get("/") + def root(request: Request): + html_content = f""" + + + Eole Server + + +

+        <body>
+        <h1>Eole Server</h1>
+        <p>Probably not what you're looking for.</p>
+        <p>API docs --> {request.url}docs.</p>
+ + + """ + return HTMLResponse(content=html_content, status_code=200) + + @app.get("/models") + def models(): + """ + Return available models currently exposed. + """ + models = server.available_models() + out = {"models": models} + return out + + @app.post("/unload_model") + def unload_model(model_id): + server.models[model_id].unload() + + @app.get("/health") + def health(): + out = {} + out["status"] = STATUS_OK + return out + + @app.post("/infer") + def infer(request: TextRequest, response_model=TextResponse): + if isinstance(request.inputs, str): + request.inputs = [request.inputs] + model_id = request.model + inputs = request.inputs + # automatically grab anything that is not model/inputs + # (we could probably rely on pydantic model once properly implemented) + non_settings_keys = ["inputs", "model"] + settings = { + k: v for k, v in request.model_dump().items() if k not in non_settings_keys + } + scores, preds = server.models[model_id].infer(inputs, settings=settings) + # returned scores are tensors which we need to cast + scores = [[score.item() for score in score_list] for score_list in scores] + response = {"predictions": preds, "scores": scores} + return response + + # @app.post("/openai/chat/completions") + # def openai_chat(request: ChatRequest): + # """ + # Simulate an OpenAI Request. + # The idea is to make this a viable alternative as a drop-in + # replacement for OpenAI or other LLM stacks. + # The actual chat -> prompt conversion might depend on the model, + # and could be defined in the inference.json config for instance. + # """ + # pass + + return app + + +@register_bin(name="serve") +class Serve(BaseBin): + @classmethod + def add_args(cls, parser): + parser.add_argument( + "--config", + "-config", + "-c", + default="./server_conf.yaml", + help="Path of server YAML config file.", + ) + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default="5000") + + @classmethod + def run(cls, args): + app = create_app(args.config) + uvicorn.run(app=app, host=args.host, port=args.port, log_level="info") diff --git a/eole/bin/tools/LM_scoring.py b/eole/bin/tools/LM_scoring.py index 65b53973..81c00dac 100644 --- a/eole/bin/tools/LM_scoring.py +++ b/eole/bin/tools/LM_scoring.py @@ -70,9 +70,16 @@ def run(cls, args): set_random_seed(config.seed, False) ppl_file = codecs.open(config.output + ".ppl", "w+", "utf-8") + # no tensor_parallel support device = ( - torch.device("cuda", config.gpu) if config.gpu > -1 else torch.device("cpu") + torch.device("cuda", config.gpu_ranks[0]) + if len(config.gpu_ranks) > 0 + else torch.device("cpu") ) + if len(config.gpu_ranks) > 1: + logger.warning( + f"gpu_ranks is {str(config.gpu_ranks)} but only the first one will be used." 
+ ) vocabs, model, model_opt = config.model.model_class.load_test_model(config) padding_idx = vocabs["tgt"][DefaultTokens.PAD] diff --git a/eole/config/__init__.py b/eole/config/__init__.py index b481a876..c42f2f0c 100644 --- a/eole/config/__init__.py +++ b/eole/config/__init__.py @@ -31,7 +31,9 @@ def recursive_model_fields_set(model): else: field_value = getattr(model, field, None) if isinstance(field_value, Config) or isinstance(field_value, dict): - fields[field] = recursive_model_fields_set(field_value) + _fields = recursive_model_fields_set(field_value) + if _fields != {}: + fields[field] = _fields else: fields[field] = field_value return fields diff --git a/eole/config/inference.py b/eole/config/inference.py index e5409180..3beb5e0f 100644 --- a/eole/config/inference.py +++ b/eole/config/inference.py @@ -10,13 +10,13 @@ class DecodingConfig(Config): ratio: float = Field( default=-0.0, description="Ratio based beam stop condition." ) # is the minus sign useful here? - random_sampling_topk: int = Field( + top_k: int = Field( default=0, description="Set this to -1 to do random sampling from full distribution. " "Set this to value k>1 to do random sampling restricted to " "the k most likely next tokens. Set this to 1 to use argmax.", ) - random_sampling_topp: float = Field( + top_p: float = Field( default=0.0, description="Probability for top-p/nucleus sampling. " "Restrict tokens to the most likely until the cumulated probability " @@ -24,7 +24,7 @@ class DecodingConfig(Config): ge=0.0, lte=1.0, ) - random_sampling_temp: float = Field( + temperature: float = Field( default=1.0, description="If doing random sampling, divide the logits by this " "before computing softmax during decoding.", @@ -44,14 +44,15 @@ class DecodingConfig(Config): default=False, description="Apply coverage penalty at every decoding step. Helpful for summary penalty.", ) - min_length: int = Field(default=0, description="Minimum prediction length.") + min_length: int = Field(default=0, description="Minimum prediction length.", ge=0) max_length: int = Field(default=250, description="Maximum prediction length.") max_length_ratio: float = Field( default=2, description="Maximum prediction length ratio. For European languages, " "2 is large enough, for target Asian charageters, " "need to increase to 2-3, for special languages (Burmese, Amharic) to 10.", - ) + ge=1, + ) # we might want to validate this against min_length block_ngram_repeat: int = Field( default=0, description="Block repetition of ngrams during decoding." ) @@ -132,9 +133,6 @@ class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig) batch_type: Literal["sents", "tokens"] = Field( default="sents", description="Batch grouping for batch size." ) - gpu: int = Field( - default=-1, description="Device to run on. -1 will default to CPU." - ) precision: Literal["", "fp32", "fp16", "int8"] = Field( default="", description="Precision to run inference. 
" diff --git a/eole/inference_engine.py b/eole/inference_engine.py index 0abe7ff3..7b39568c 100755 --- a/eole/inference_engine.py +++ b/eole/inference_engine.py @@ -37,7 +37,7 @@ def infer_file(self): scores, estims, preds = self.infer_file_parallel() return scores, estims, preds - def infer_list(self, src): + def infer_list(self, src, settings={}): """List of strings inference `src`""" if self.config.world_size <= 1: infer_iter = build_dynamic_dataset_iter( @@ -48,9 +48,9 @@ def infer_list(self, src): src=src, device_id=self.device_id, ) - scores, estims, preds = self._predict(infer_iter) + scores, estims, preds = self._predict(infer_iter, settings=settings) else: - scores, estims, preds = self.infer_list_parallel(src) + scores, estims, preds = self.infer_list_parallel(src, settings=settings) return scores, estims, preds def infer_file_parallel(self): @@ -59,7 +59,7 @@ def infer_file_parallel(self): "The inference in mulitprocessing with partitioned models is not implemented." ) - def infer_list_parallel(self, src): + def infer_list_parallel(self, src, settings={}): """The inference in mulitprocessing with partitioned models.""" raise NotImplementedError( "The inference in mulitprocessing with partitioned models is not implemented." @@ -125,11 +125,13 @@ def __init__(self, config): self.error_queue = mp.SimpleQueue() self.error_handler = ErrorHandler(self.error_queue) self.queue_instruct = [] + self.queue_settings = [] self.queue_result = [] self.procs = [] for device_id in range(config.world_size): self.queue_instruct.append(mp.Queue()) + self.queue_settings.append(mp.Queue()) self.queue_result.append(mp.Queue()) self.procs.append( mp.Process( @@ -147,7 +149,10 @@ def __init__(self, config): self.procs[device_id].start() self.error_handler.add_child(self.procs[device_id].pid) else: - self.device_id = config.gpu + if len(config.gpu_ranks) > 0: + self.device_id = config.gpu_ranks[0] + else: + self.device_id = -1 # cpu self.predictor = build_predictor( config, self.device_id, logger=self.logger, report_score=True ) @@ -156,7 +161,8 @@ def __init__(self, config): self.transforms = make_transforms(config, self.transforms_cls, self.vocabs) self.transform_pipe = TransformPipe.build_from(self.transforms.values()) - def _predict(self, infer_iter): + def _predict(self, infer_iter, settings={}): + self.predictor.update_settings(**settings) scores, estims, preds = self.predictor._predict( infer_iter, infer_iter.transforms, @@ -188,20 +194,23 @@ def score_file_parallel(self): score_results.append(self.queue_result[device_id].get()) return score_results[0] - def infer_file_parallel(self): + def infer_file_parallel(self, settings={}): assert self.config.world_size > 1, "World size must be greater than 1." for device_id in range(self.config.world_size): self.queue_instruct[device_id].put(("infer_file", self.config)) + # not sure if we want a separate queue or additional info in queue_instruct + self.queue_settings[device_id].put(settings) scores, preds = [], [] for device_id in range(self.config.world_size): scores.append(self.queue_result[device_id].get()) preds.append(self.queue_result[device_id].get()) return scores[0], preds[0] - def infer_list_parallel(self, src): + def infer_list_parallel(self, src, settings={}): assert self.config.world_size > 1, "World size must be greater than 1." 
for device_id in range(self.config.world_size): self.queue_instruct[device_id].put(("infer_list", src)) + self.queue_settings[device_id].put(settings) scores, preds = [], [] for device_id in range(self.config.world_size): scores.append(self.queue_result[device_id].get()) @@ -232,11 +241,12 @@ def __init__(self, config, model_task=None): ), "A model_task kwarg must be passed for CT2 models." self.logger = init_logger(config.log_file) assert self.config.world_size <= 1, "World size must be less than 1." - self.device_id = config.gpu if config.world_size == 1: + self.device_id = config.gpu_ranks[0] self.device_index = config.gpu_ranks self.device = "cuda" else: + self.device_id = -1 self.device_index = 0 self.device = "cpu" self.transforms_cls = get_transforms_cls(self.config._all_transform) @@ -289,9 +299,9 @@ def predict_batch(self, batch, config): max_length=config.max_length, return_scores=True, include_prompt_in_result=False, - sampling_topk=config.random_sampling_topk, - sampling_topp=config.random_sampling_topp, - sampling_temperature=config.random_sampling_temp, + sampling_topk=config.top_k, + sampling_topp=config.top_p, + sampling_temperature=config.temperature, ) preds = [ [self.transforms.apply_reverse(tokens) for tokens in out.sequences] @@ -307,9 +317,9 @@ def predict_batch(self, batch, config): num_hypotheses=config.n_best, max_decoding_length=config.max_length, return_scores=True, - sampling_topk=config.random_sampling_topk, - sampling_topp=config.random_sampling_topp, - sampling_temperature=config.random_sampling_temp, + sampling_topk=config.top_k, + sampling_topp=config.top_p, + sampling_temperature=config.temperature, ) preds = [ [self.transforms.apply_reverse(tokens) for tokens in out.hypotheses] @@ -319,7 +329,8 @@ def predict_batch(self, batch, config): return scores, None, preds - def _predict(self, infer_iter): + def _predict(self, infer_iter, settings={}): + # TODO: convert settings to CT2 naming scores = [] preds = [] for batch, bucket_idx in infer_iter: diff --git a/eole/models/model.py b/eole/models/model.py index c0a9e84d..0167ba93 100644 --- a/eole/models/model.py +++ b/eole/models/model.py @@ -358,8 +358,6 @@ def inference_logic(self, checkpoint, running_config, vocabs, device_id=None): if use_gpu(running_config): if len(running_config.gpu_ranks) > 0: device_id = running_config.gpu_ranks[0] - elif running_config.gpu > -1: - device_id = running_config.gpu device = torch.device("cuda", device_id) else: device = torch.device("cpu") diff --git a/eole/models/model_saver.py b/eole/models/model_saver.py index dd4b0bf9..b07d390d 100644 --- a/eole/models/model_saver.py +++ b/eole/models/model_saver.py @@ -261,6 +261,15 @@ def _save_config(self): with open(config_path, "w") as f: json.dump(config_data, f, indent=2, ensure_ascii=False) self._make_symlink("config.json") + # save transforms related config for inference + inference_keys = ["transforms", "transforms_configs"] + inference_data = { + k: config_data[k] for k in inference_keys if k in config_data.keys() + } + inference_path = os.path.join(self.model_path, self.step_dir, "inference.json") + with open(inference_path, "w") as f: + json.dump(inference_data, f, indent=2, ensure_ascii=False) + self._make_symlink("inference.json") def _save_transforms_artifacts(self): if self.transforms is not None: diff --git a/eole/predict/__init__.py b/eole/predict/__init__.py index 7e78e5ea..c09b5c98 100644 --- a/eole/predict/__init__.py +++ b/eole/predict/__init__.py @@ -58,6 +58,7 @@ def build_predictor(config, device_id=0, 
report_score=True, logger=None, out_fil vocabs, config, model_config, + device_id=device_id, global_scorer=scorer, out_file=out_file, report_align=config.report_align, diff --git a/eole/predict/encoder.py b/eole/predict/encoder.py index 6e221df6..69370b15 100644 --- a/eole/predict/encoder.py +++ b/eole/predict/encoder.py @@ -24,7 +24,7 @@ def predict_batch(self, batch, attn_debug): else: max_length = self.max_length with torch.no_grad(): - if self.sample_from_topk != 0 or self.sample_from_topp != 0: + if self.top_k != 0 or self.top_p != 0: decode_strategy = GreedySearch( pad=self._tgt_pad_idx, bos=self._tgt_bos_idx, @@ -39,9 +39,9 @@ def predict_batch(self, batch, attn_debug): block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs, return_attention=attn_debug or self.replace_unk, - sampling_temp=self.random_sampling_temp, - keep_topk=self.sample_from_topk, - keep_topp=self.sample_from_topp, + sampling_temp=self.temperature, + top_k=self.top_k, + top_p=self.top_p, beam_size=self.beam_size, ban_unk_token=self.ban_unk_token, ) diff --git a/eole/predict/generator.py b/eole/predict/generator.py index a43018b5..fa0efb31 100644 --- a/eole/predict/generator.py +++ b/eole/predict/generator.py @@ -25,7 +25,7 @@ def predict_batch(self, batch, attn_debug, scoring=False): """Predict a batch of sentences.""" max_length = 0 if scoring else self.max_length with torch.no_grad(): - if self.sample_from_topk != 0 or self.sample_from_topp != 0: + if self.top_k != 0 or self.top_p != 0: decode_strategy = GreedySearchLM( pad=self._tgt_pad_idx, bos=self._tgt_bos_idx, @@ -40,9 +40,9 @@ def predict_batch(self, batch, attn_debug, scoring=False): block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs, return_attention=attn_debug or self.replace_unk, - sampling_temp=self.random_sampling_temp, - keep_topk=self.sample_from_topk, - keep_topp=self.sample_from_topp, + temperature=self.temperature, + top_k=self.top_k, + top_p=self.top_p, beam_size=self.beam_size, ban_unk_token=self.ban_unk_token, ) diff --git a/eole/predict/greedy_search.py b/eole/predict/greedy_search.py index 562c0ca0..02a8e870 100644 --- a/eole/predict/greedy_search.py +++ b/eole/predict/greedy_search.py @@ -3,11 +3,11 @@ from eole.predict.decode_strategy import DecodeStrategy -def sample_topp(logits, keep_topp): +def sample_topp(logits, top_p): sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=1) cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1) - sorted_indices_to_keep = cumulative_probs.lt(keep_topp) + sorted_indices_to_keep = cumulative_probs.lt(top_p) # keep indices until overflowing p cumsum_mask = sorted_indices_to_keep.cumsum(dim=1) @@ -25,8 +25,8 @@ def sample_topp(logits, keep_topp): return logits.masked_fill(~keep_indices, -10000) -def sample_topk(logits, keep_topk): - top_values, _ = torch.topk(logits, keep_topk, dim=1) +def sample_topk(logits, top_k): + top_values, _ = torch.topk(logits, top_k, dim=1) kth_best = top_values[:, -1].view([-1, 1]) kth_best = kth_best.repeat([1, logits.shape[1]]).float() @@ -36,11 +36,11 @@ def sample_topk(logits, keep_topk): return logits.masked_fill(ignore, -10000) -def sample_with_temperature(logits, sampling_temp, keep_topk, keep_topp): +def sample_with_temperature(logits, temperature, top_k, top_p): """Select next tokens randomly from the top k possible next tokens. - Samples from a categorical distribution over the ``keep_topk`` words using - the category probabilities ``logits / sampling_temp``. 
+ Samples from a categorical distribution over the ``top_k`` words using + the category probabilities ``logits / temperature``. Args: logits (FloatTensor): Shaped ``(batch_size, vocab_size)``. @@ -48,13 +48,13 @@ def sample_with_temperature(logits, sampling_temp, keep_topk, keep_topp): (The distribution actually uses the log-probabilities ``logits - logits.logsumexp(-1)``, which equals the logits if they are log-probabilities summing to 1.) - sampling_temp (float): Used to scale down logits. The higher the + temperature (float): Used to scale down logits. The higher the value, the more likely it is that a non-max word will be sampled. - keep_topk (int): This many words could potentially be chosen. The + top_k (int): This many words could potentially be chosen. The other logits are set to have probability 0. - keep_topp (float): Keep most likely words until the cumulated - probability is greater than p. If used with keep_topk: both + top_p (float): Keep most likely words until the cumulated + probability is greater than p. If used with top_k: both conditions will be applied Returns: @@ -63,21 +63,21 @@ def sample_with_temperature(logits, sampling_temp, keep_topk, keep_topp): * topk_ids: Shaped ``(batch_size, 1)``. These are the sampled word indices in the output vocab. * topk_scores: Shaped ``(batch_size, 1)``. These - are essentially ``(logits / sampling_temp)[topk_ids]``. + are essentially ``(logits / temperature)[topk_ids]``. """ - if sampling_temp == 0.0 or keep_topk == 1: + if temperature == 0.0 or top_k == 1: # For temp=0.0, take the argmax to avoid divide-by-zero errors. - # keep_topk=1 is also equivalent to argmax. + # top_k=1 is also equivalent to argmax. topk_scores, topk_ids = logits.topk(1, dim=-1) - if sampling_temp > 0: - topk_scores /= sampling_temp + if temperature > 0: + topk_scores /= temperature else: - logits = torch.div(logits, sampling_temp) - if keep_topp > 0: - logits = sample_topp(logits, keep_topp) - if keep_topk > 0: - logits = sample_topk(logits, keep_topk) + logits = torch.div(logits, temperature) + if top_p > 0: + logits = sample_topp(logits, top_p) + if top_k > 0: + logits = sample_topk(logits, top_k) dist = torch.distributions.Categorical(logits=logits) topk_ids = dist.sample().view(-1, 1) topk_scores = logits.gather(dim=1, index=topk_ids) @@ -108,11 +108,11 @@ class GreedySearch(DecodeStrategy): exclusion_tokens (set[int]): See base. return_attention (bool): See base. max_length (int): See base. - sampling_temp (float): See + temperature (float): See :func:`~eole.predict.greedy_search.sample_with_temperature()`. - keep_topk (int): See + top_k (int): See :func:`~eole.predict.greedy_search.sample_with_temperature()`. - keep_topp (float): See + top_p (float): See :func:`~eole.predict.greedy_search.sample_with_temperature()`. beam_size (int): Number of beams to use. 
""" @@ -132,9 +132,9 @@ def __init__( exclusion_tokens, return_attention, max_length, - sampling_temp, - keep_topk, - keep_topp, + temperature, + top_k, + top_p, beam_size, ban_unk_token, add_estimator=False, @@ -156,9 +156,9 @@ def __init__( ban_unk_token, add_estimator, ) - self.sampling_temp = sampling_temp - self.keep_topk = keep_topk - self.keep_topp = keep_topp + self.temperature = temperature + self.top_k = top_k + self.top_p = top_p self.topk_scores = None self.beam_size = beam_size self.n_best = n_best @@ -200,7 +200,7 @@ def _pick(self, log_probs): # maybe fix some prediction at this step by modifying log_probs log_probs = self.target_prefixing(log_probs) topk_ids, topk_scores = sample_with_temperature( - log_probs, self.sampling_temp, self.keep_topk, self.keep_topp + log_probs, self.temperature, self.top_k, self.top_p ) return topk_ids, topk_scores diff --git a/eole/predict/inference.py b/eole/predict/inference.py index f37b7463..4d044580 100644 --- a/eole/predict/inference.py +++ b/eole/predict/inference.py @@ -28,9 +28,11 @@ class Inference(object): max_length (int): See :class:`eole.predict.decode_strategy.DecodeStrategy`. beam_size (int): Number of beams. - random_sampling_topk (int): See + top_p (float): See :class:`eole.predict.greedy_search.GreedySearch`. - random_sampling_temp (float): See + top_k (int): See + :class:`eole.predict.greedy_search.GreedySearch`. + temperature (float): See :class:`eole.predict.greedy_search.GreedySearch`. stepwise_penalty (bool): Whether coverage penalty is applied every step or not. @@ -62,9 +64,9 @@ def __init__( max_length_ratio=1.5, ratio=0.0, beam_size=30, - random_sampling_topk=0, - random_sampling_topp=0.0, - random_sampling_temp=1.0, + top_k=0, + top_p=0.0, + temperature=1.0, stepwise_penalty=None, dump_beam=False, block_ngram_repeat=0, @@ -109,9 +111,9 @@ def __init__( self.max_length_ratio = max_length_ratio self.beam_size = beam_size - self.random_sampling_temp = random_sampling_temp - self.sample_from_topk = random_sampling_topk - self.sample_from_topp = random_sampling_topp + self.temperature = temperature + self.top_k = top_k + self.top_p = top_p self.min_length = min_length self.ban_unk_token = ban_unk_token @@ -166,6 +168,7 @@ def from_config( vocabs, config, # running/predict config model_config, + device_id=0, global_scorer=None, out_file=None, report_align=False, @@ -194,16 +197,16 @@ def from_config( return cls( model, vocabs, - gpu=config.gpu, + gpu=device_id, n_best=config.n_best, min_length=config.min_length, max_length=config.max_length, max_length_ratio=config.max_length_ratio, ratio=config.ratio, beam_size=config.beam_size, - random_sampling_topk=config.random_sampling_topk, - random_sampling_topp=config.random_sampling_topp, - random_sampling_temp=config.random_sampling_temp, + top_k=config.top_k, + top_p=config.top_p, + temperature=config.temperature, stepwise_penalty=config.stepwise_penalty, dump_beam=config.dump_beam, block_ngram_repeat=config.block_ngram_repeat, @@ -243,6 +246,12 @@ def _gold_score(self, batch, enc_out, src_len, enc_final_hs, batch_size, src): glp = None return gs, glp + def update_settings(self, **kwargs): + # we probably would need some validation at some point + for k, v in kwargs.items(): + if hasattr(self, k): + setattr(self, k, v) + def _predict( self, infer_iter, diff --git a/eole/predict/translator.py b/eole/predict/translator.py index b36acf65..439ce08b 100644 --- a/eole/predict/translator.py +++ b/eole/predict/translator.py @@ -74,7 +74,7 @@ def predict_batch(self, batch, attn_debug): 
else: max_length = self.max_length with torch.no_grad(): - if self.sample_from_topk != 0 or self.sample_from_topp != 0: + if self.top_k != 0 or self.top_p != 0: decode_strategy = GreedySearch( pad=self._tgt_pad_idx, bos=self._tgt_bos_idx, @@ -89,9 +89,9 @@ def predict_batch(self, batch, attn_debug): block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs, return_attention=attn_debug or self.replace_unk, - sampling_temp=self.random_sampling_temp, - keep_topk=self.sample_from_topk, - keep_topp=self.sample_from_topp, + temperature=self.temperature, + top_k=self.top_k, + top_p=self.top_p, beam_size=self.beam_size, ban_unk_token=self.ban_unk_token, ) diff --git a/eole/tests/data/inference-engine_py.yaml b/eole/tests/data/inference-engine_py.yaml index 89db33dd..f738c678 100644 --- a/eole/tests/data/inference-engine_py.yaml +++ b/eole/tests/data/inference-engine_py.yaml @@ -2,9 +2,9 @@ world_size: 0 max_length: 512 batch_type: sents batch_size: 100 -random_sampling_topk: 40 -random_sampling_topp: 0.75 -random_sampling_temp: 0.1 +top_k: 40 +top_p: 0.75 +temperature: 0.1 beam_size: 2 n_best: 2 src: None diff --git a/eole/tests/pull_request_check.sh b/eole/tests/pull_request_check.sh index 0dd0e42e..09870052 100755 --- a/eole/tests/pull_request_check.sh +++ b/eole/tests/pull_request_check.sh @@ -352,8 +352,8 @@ ${PYTHON} eole/bin/main.py predict -model_path ${TEST_DIR}/test_model2 \ -verbose -batch_size 10 \ -beam_size 1 \ -seed 1 \ - -random_sampling_topk -1 \ - -random_sampling_temp 0.0001 \ + -top_k -1 \ + -temperature 0.0001 \ -tgt ${DATA_DIR}/morph/tgt.valid \ -out $TMP_OUT_DIR/trans_sampling >> ${LOG_FILE} 2>&1 diff ${DATA_DIR}/morph/tgt.valid $TMP_OUT_DIR/trans_sampling @@ -389,8 +389,8 @@ ${PYTHON} eole/bin/main.py predict -model_path ${TEST_DIR}/test_model_lm \ -verbose -batch_size 1 \ -beam_size 1 \ -seed 1 \ - -random_sampling_topk -1 \ - -random_sampling_temp 0.0001 \ + -top_k -1 \ + -temperature 0.0001 \ -ban_unk_token \ -length_penalty none \ -out $TMP_OUT_DIR/gen_sampling >> ${LOG_FILE} 2>&1 @@ -405,9 +405,9 @@ ${PYTHON} eole/bin/main.py predict -model_path ${TEST_DIR}/test_model_lm \ -verbose -batch_size 1 \ -beam_size 1 \ -seed 3 \ - -random_sampling_topk -1 \ - -random_sampling_topp 0.95 \ - -random_sampling_temp 1 \ + -top_k -1 \ + -top_p 0.95 \ + -temperature 1 \ -ban_unk_token \ -length_penalty none \ -out $TMP_OUT_DIR/gen_sampling >> ${LOG_FILE} 2>&1 @@ -422,9 +422,9 @@ ${PYTHON} eole/bin/main.py predict -model_path ${TEST_DIR}/test_model_lm \ -verbose -batch_size 1 \ -beam_size 10 \ -seed 2 \ - -random_sampling_topk 50 \ - -random_sampling_topp 0.95 \ - -random_sampling_temp 1 \ + -top_k 50 \ + -top_p 0.95 \ + -temperature 1 \ -length_penalty avg \ -ban_unk_token \ -min_length 5 \ diff --git a/eole/utils/distributed.py b/eole/utils/distributed.py index b35e38f8..a77d7484 100644 --- a/eole/utils/distributed.py +++ b/eole/utils/distributed.py @@ -186,7 +186,9 @@ def spawned_train(process_fn, config, device_id, error_queue): # noqa: E501 error_queue.put((config.training.gpu_ranks[device_id], traceback.format_exc())) -def spawned_infer(config, device_id, error_queue, queue_instruct, queue_result): +def spawned_infer( + config, device_id, error_queue, queue_instruct, queue_result, queue_settings=None +): """Run various functions for prediction in spawned process on `device_id`.""" try: running_config = ( @@ -205,6 +207,9 @@ def spawned_infer(config, device_id, error_queue, queue_instruct, queue_result): transforms = make_transforms(config, transforms_cls, 
predictor.vocabs) while True: instruction = queue_instruct.get() + if queue_settings is not None: + settings = queue_settings.get() + predictor.update_settings(**settings) if instruction[0] == "stop": break elif instruction[0] == "infer_list": diff --git a/eole/utils/loss.py b/eole/utils/loss.py index 7b2b79d2..a9eb81cb 100644 --- a/eole/utils/loss.py +++ b/eole/utils/loss.py @@ -98,13 +98,13 @@ def from_config(cls, config, model, vocab, train=True): if config.training.lm_prior_model: if config.training.lm_prior_model[-3:] == ".pt": # TODO: we should probably find a way around this - config.gpu = 0 + # config.gpu = 0 config.fp32 = False config.int8 = False _, lm_prior_model, lm_model_config = DecoderModel.load_test_model( config, model_path=config.training.lm_prior_model ) # lm_model_config does not seem used - lm_prior_model.to(torch.device("cuda", config.training.gpu)) + # lm_prior_model.to(torch.device("cuda", config.training.gpu)) lm_prior_model.eval() lm_generator = None else: diff --git a/eole/utils/misc.py b/eole/utils/misc.py index 6ef2cc47..1732bbbd 100644 --- a/eole/utils/misc.py +++ b/eole/utils/misc.py @@ -69,9 +69,7 @@ def use_gpu(config): """ Creates a boolean if gpu used """ - return (hasattr(config, "gpu_ranks") and len(config.gpu_ranks) > 0) or ( - hasattr(config, "gpu") and config.gpu > -1 - ) + return hasattr(config, "gpu_ranks") and len(config.gpu_ranks) > 0 def set_random_seed(seed, is_cuda): diff --git a/eole/utils/scoring_utils.py b/eole/utils/scoring_utils.py index 379bb3bd..e90af3aa 100644 --- a/eole/utils/scoring_utils.py +++ b/eole/utils/scoring_utils.py @@ -50,7 +50,7 @@ def translate(self, model, gpu_rank, step): # (take 'inference' field of config if exists?) # Set "default" translation options on empty cfgfile predict_config = PredictConfig(model_path=["dummy"], src="dummy") - predict_config.gpu = gpu_rank + # predict_config.gpu = gpu_rank if predict_config.transforms_configs.prefix.tgt_prefix != "": predict_config.tgt_file_prefix = True predict_config.beam_size = 1 # prevent OOM when GPU is almost full at training @@ -66,6 +66,7 @@ def translate(self, model, gpu_rank, step): self.vocabs, predict_config, model_config, + device_id=gpu_rank, global_scorer=scorer, out_file=out_file, report_align=predict_config.report_align, @@ -99,7 +100,7 @@ def translate(self, model, gpu_rank, step): translator.vocabs, task=CorpusTask.INFER, tgt="", # This force to clear the target side (needed when using tgt_file_prefix) - device_id=predict_config.gpu, + device_id=gpu_rank, ) # ########### # diff --git a/recipes/gpt2/inference.yaml b/recipes/gpt2/inference.yaml index bd1e2198..68cce737 100644 --- a/recipes/gpt2/inference.yaml +++ b/recipes/gpt2/inference.yaml @@ -7,17 +7,15 @@ transforms_configs: world_size: 1 gpu_ranks: [0] -gpu: 0 model_path: ${EOLE_MODEL_DIR}/openai_gpt2 src: lm_input.txt output: lm_pred.txt beam_size: 5 -# random_sampling_topp: 0.5 -random_sampling_temp: 1.0 -random_sampling_topk: 50 -random_sampling_topp: 1 +temperature: 1.0 +top_k: 50 +top_p: 1 n_best: 5 seed: 42 diff --git a/recipes/llama2/llama-inference-tp-2gpu.yaml b/recipes/llama2/llama-inference-tp-2gpu.yaml index 146bdcf5..4f4d7929 100755 --- a/recipes/llama2/llama-inference-tp-2gpu.yaml +++ b/recipes/llama2/llama-inference-tp-2gpu.yaml @@ -10,7 +10,6 @@ model_path: "${EOLE_MODEL_DIR}/llama2-7b-chat-hf/model.pt" # Inference seed: 42 max_length: 256 -gpu: 0 batch_type: sents batch_size: 8 world_size: 2 @@ -19,9 +18,9 @@ parallel_mode: "tensor_parallel" quant_layers: ['gate_up_proj', 'down_proj', 
'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear'] quant_type: "bnb_NF4" precision: fp16 -random_sampling_topk: 5 -random_sampling_topp: 0.8 -random_sampling_temp: 0.9 +top_k: 5 +top_p: 0.8 +temperature: 0.9 beam_size: 1 n_best: 1 report_time: true diff --git a/recipes/llama2/llama-inference.yaml b/recipes/llama2/llama-inference.yaml index 3105987b..7113db4f 100755 --- a/recipes/llama2/llama-inference.yaml +++ b/recipes/llama2/llama-inference.yaml @@ -10,7 +10,6 @@ model_path: "${EOLE_MODEL_DIR}/llama2-7b-chat-hf" # Inference seed: 42 max_length: 256 -gpu: 0 batch_type: sents batch_size: 8 world_size: 1 @@ -19,9 +18,9 @@ gpu_ranks: [0] quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear'] quant_type: "bnb_NF4" precision: fp16 -#random_sampling_topk: 1 -#random_sampling_topp: 0.0 -#random_sampling_temp: 0.9 +#top_k: 1 +#top_p: 0.0 +#temperature: 0.9 beam_size: 1 n_best: 1 report_time: true diff --git a/recipes/llama3/llama-inference.yaml b/recipes/llama3/llama-inference.yaml index d0f0611f..6c70224d 100755 --- a/recipes/llama3/llama-inference.yaml +++ b/recipes/llama3/llama-inference.yaml @@ -15,7 +15,6 @@ model_path: "${EOLE_MODEL_DIR}/llama3-8b-instruct" seed: 42 max_length: 256 # max_length: 1 -gpu: 0 batch_type: sents batch_size: 4 world_size: 1 @@ -23,12 +22,12 @@ gpu_ranks: [0] # world_size: 2 # gpu_ranks: [0, 1] # parallel_mode: "tensor_parallel" -# quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear'] -# quant_type: "bnb_NF4" +quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear'] +quant_type: "bnb_NF4" precision: fp16 -#random_sampling_topk: 1 -#random_sampling_topp: 0.0 -#random_sampling_temp: 0.9 +#top_k: 1 +#top_p: 0.0 +#temperature: 0.9 beam_size: 1 n_best: 1 report_time: true diff --git a/recipes/llama3/llama-mmlu.yaml b/recipes/llama3/llama-mmlu.yaml index 54ba12dc..1b8a28d0 100755 --- a/recipes/llama3/llama-mmlu.yaml +++ b/recipes/llama3/llama-mmlu.yaml @@ -15,7 +15,6 @@ model_path: "${EOLE_MODEL_DIR}/llama3-8b-instruct/model.pt" seed: 42 # max_length: 256 max_length: 1 -gpu: 0 batch_type: sents batch_size: 1 world_size: 1 @@ -26,9 +25,6 @@ gpu_ranks: [0] # quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear'] # quant_type: "bnb_NF4" precision: fp16 -#random_sampling_topk: 1 -#random_sampling_topp: 0.0 -#random_sampling_temp: 0.9 beam_size: 1 n_best: 1 report_time: true diff --git a/recipes/mistral/mistral-7b-awq-gemm-inference.yaml b/recipes/mistral/mistral-7b-awq-gemm-inference.yaml index b88f9534..56f414fb 100755 --- a/recipes/mistral/mistral-7b-awq-gemm-inference.yaml +++ b/recipes/mistral/mistral-7b-awq-gemm-inference.yaml @@ -10,7 +10,6 @@ model_path: "$EOLE_MODEL_DIR/mistral-7b-instruct-v0.2-awq" # Inference seed: 42 max_length: 256 -gpu: 0 batch_type: sents batch_size: 8 world_size: 1 @@ -21,9 +20,9 @@ gpu_ranks: [0] #quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear'] #quant_type: "bnb_NF4" precision: fp16 -#random_sampling_topk: 1 -#random_sampling_topp: 0.6 -#random_sampling_temp: 0.9 +#top_k: 1 +#top_p: 0.6 +#temperature: 0.9 beam_size: 1 n_best: 1 profile: false diff --git a/recipes/mixtral/mixtral-inference-awq.yaml b/recipes/mixtral/mixtral-inference-awq.yaml index b2daa4d5..7cfb006b 100755 --- 
a/recipes/mixtral/mixtral-inference-awq.yaml +++ b/recipes/mixtral/mixtral-inference-awq.yaml @@ -10,7 +10,6 @@ model_path: "${EOLE_MODEL_DIR}/mixtral-8x7b-instruct-v0.1-awq" # Inference seed: 42 max_length: 256 -gpu: 0 batch_type: sents batch_size: 1 world_size: 2 @@ -20,9 +19,9 @@ parallel_mode: "tensor_parallel" #quant_layers: ['gate_up_proj', 'down_proj', 'up_proj'] #quant_type: "bnb_sparse" precision: fp16 -#random_sampling_topk: 1 -#random_sampling_topp: 0.6 -#random_sampling_temp: 0.9 +#top_k: 1 +#top_p: 0.6 +#temperature: 0.9 beam_size: 1 n_best: 1 profile: false diff --git a/recipes/wiki_103/inference.yaml b/recipes/wiki_103/inference.yaml index c7532183..42364071 100644 --- a/recipes/wiki_103/inference.yaml +++ b/recipes/wiki_103/inference.yaml @@ -10,10 +10,9 @@ transforms_configs: verbose: false n_best: 3 -random_sampling_topp: 0.9 +top_p: 0.9 beam_size: 10 -gpu: 0 world_size: 1 gpu_ranks: [0]
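
Usage sketch for the new /infer route added by eole/bin/run/serve.py: the request body follows TextRequest/DecodingConfig (model, inputs, plus the renamed top_k / top_p / temperature fields), and the response follows TextResponse; any field other than model/inputs is forwarded to the predictor via update_settings, so sampling can be tuned per request. The host, port and model id below are illustrative assumptions, not values taken from this diff (model ids come from the "models" list of the server YAML passed to `eole serve -c`).

#!/usr/bin/env python
# Hypothetical client for the /infer endpoint (values marked as assumptions above).
import requests

payload = {
    # "model" and "inputs" match the TextRequest fields defined in serve.py
    "model": "my-model",                      # id declared in the server YAML (assumption)
    "inputs": "What is the Eole project?",    # a single string is cast to a list server-side
    # decoding options use the renamed DecodingConfig fields from this diff
    "top_k": 50,
    "top_p": 0.95,
    "temperature": 0.9,
}

resp = requests.post("http://localhost:5000/infer", json=payload, timeout=120)
resp.raise_for_status()
result = resp.json()

# TextResponse: "predictions" is List[List[str]], "scores" is List[List[float]]
for preds, scores in zip(result["predictions"], result["scores"]):
    for pred, score in zip(preds, scores):
        print(f"{score:.4f}\t{pred}")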