diff --git a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py
index 05f91f43a..4d2e23b92 100644
--- a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py
+++ b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py
@@ -324,7 +324,7 @@ def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_
 
         tokenizer.padding_side = 'left'
         model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        device = torch.device(self.language_model.device if torch.cuda.is_available() else 'cpu')
         input_ids = model_inputs['input_ids'].to(device)
         attention_mask = model_inputs['attention_mask'].to(device)
         eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
@@ -374,7 +374,7 @@ def chat(self, tokenizer, pixel_values, question, generation_config, history=Non
 
         query = query.replace('<image>', image_tokens, 1)
         model_inputs = tokenizer(query, return_tensors='pt')
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        device = torch.device(self.language_model.device if torch.cuda.is_available() else 'cpu')
         input_ids = model_inputs['input_ids'].to(device)
         attention_mask = model_inputs['attention_mask'].to(device)
         generation_config['eos_token_id'] = eos_token_id