@@ -2,6 +2,7 @@ use crate::types::{RType, TypeAnnotation, SourceRange};
22use crate :: parser:: { TypeParser , ParseError } ;
33use std:: collections:: HashMap ;
44use thiserror:: Error ;
5+ use tree_sitter:: { Node , Parser , Tree } ;
56
67#[ derive( Error , Debug ) ]
78pub enum TypeCheckError {
@@ -60,12 +61,28 @@ impl TypeContext {
6061/// Main type checker
6162pub struct TypeChecker {
6263 context : TypeContext ,
64+ parser : Parser ,
65+ }
66+
67+ /// Create a new tree-sitter parser for R
68+ fn new_parser ( ) -> Parser {
69+ let mut parser = Parser :: new ( ) ;
70+ parser
71+ . set_language ( & tree_sitter_r:: LANGUAGE . into ( ) )
72+ . expect ( "Error loading R parser" ) ;
73+ parser
74+ }
75+
76+ /// Parse text into a tree-sitter tree
77+ fn parse_expression ( parser : & mut Parser , text : & str ) -> Option < Tree > {
78+ parser. parse ( text, None )
6379}
6480
6581impl TypeChecker {
6682 pub fn new ( ) -> Self {
6783 Self {
6884 context : TypeContext :: new ( ) ,
85+ parser : new_parser ( ) ,
6986 }
7087 }
7188
@@ -189,8 +206,121 @@ impl TypeChecker {
189206 Ok ( ( ) )
190207 }
191208
192- /// Infer type from R literal values
193- pub fn infer_literal_type ( & self , value : & str ) -> RType {
209+ /// Infer type from R expression using tree-sitter
210+ pub fn infer_expression_type ( & mut self , expression : & str ) -> RType {
211+ let trimmed = expression. trim ( ) ;
212+
213+ // First try to parse with tree-sitter
214+ if let Some ( tree) = parse_expression ( & mut self . parser , trimmed) {
215+ match self . check_node ( tree. root_node ( ) ) {
216+ Ok ( r_type) => return r_type,
217+ Err ( _) => {
218+ // Fall back to literal inference on tree-sitter errors
219+ }
220+ }
221+ }
222+
223+ // Fallback to the original literal-based inference
224+ self . infer_literal_type ( expression)
225+ }
226+
227+ /// Check type of a tree-sitter node (similar to the prototype)
228+ fn check_node ( & self , node : Node ) -> Result < RType , TypeCheckError > {
229+ match node. kind ( ) {
230+ "float" => Ok ( RType :: Numeric ) ,
231+ "integer" => Ok ( RType :: Integer ) ,
232+ "string" => Ok ( RType :: Character ) ,
233+ "true" | "false" => Ok ( RType :: Logical ) ,
234+ "null" => Ok ( RType :: Null ) ,
235+ "binary_operator" => self . check_binary_operator ( node) ,
236+ "identifier" => {
237+ // For now, return Unknown for identifiers since we need the source text
238+ // to look up variables, which we don't have in this context
239+ Ok ( RType :: Unknown )
240+ }
241+ "program" => {
242+ // Return the type of the last expression in the program
243+ let mut walker = node. walk ( ) ;
244+ for child in node. children ( & mut walker) {
245+ if !child. is_error ( ) && child. kind ( ) != "comment" {
246+ return self . check_node ( child) ;
247+ }
248+ }
249+ Ok ( RType :: Unknown )
250+ }
251+ _ => Ok ( RType :: Unknown ) ,
252+ }
253+ }
254+
255+ /// Check type of binary operator expressions
256+ fn check_binary_operator ( & self , node : Node ) -> Result < RType , TypeCheckError > {
257+ let lhs = node. child_by_field_name ( "lhs" ) . ok_or ( TypeCheckError :: InvalidFunctionCall {
258+ message : "Binary operator missing left operand" . to_string ( ) ,
259+ } ) ?;
260+ let rhs = node. child_by_field_name ( "rhs" ) . ok_or ( TypeCheckError :: InvalidFunctionCall {
261+ message : "Binary operator missing right operand" . to_string ( ) ,
262+ } ) ?;
263+ let operator = node. child_by_field_name ( "operator" ) . ok_or ( TypeCheckError :: InvalidFunctionCall {
264+ message : "Binary operator missing operator" . to_string ( ) ,
265+ } ) ?;
266+
267+ let lhs_type = self . check_node ( lhs) ?;
268+ let rhs_type = self . check_node ( rhs) ?;
269+
270+ match operator. kind ( ) {
271+ "+" => self . check_addition ( lhs_type, rhs_type) ,
272+ "-" | "*" | "/" => self . check_arithmetic ( lhs_type, rhs_type) ,
273+ "==" | "!=" | "<" | ">" | "<=" | ">=" => Ok ( RType :: Logical ) ,
274+ "&&" | "||" => self . check_logical ( lhs_type, rhs_type) ,
275+ _ => Ok ( RType :: Unknown ) ,
276+ }
277+ }
278+
279+ /// Check addition operation (can be arithmetic or string concatenation)
280+ fn check_addition ( & self , lhs_type : RType , rhs_type : RType ) -> Result < RType , TypeCheckError > {
281+ match ( lhs_type, rhs_type) {
282+ ( RType :: Integer , RType :: Integer ) => Ok ( RType :: Integer ) ,
283+ ( RType :: Numeric , RType :: Numeric ) => Ok ( RType :: Numeric ) ,
284+ ( RType :: Integer , RType :: Numeric ) | ( RType :: Numeric , RType :: Integer ) => Ok ( RType :: Numeric ) ,
285+ ( RType :: Character , RType :: Character ) => Ok ( RType :: Character ) ,
286+ ( RType :: Any , _) | ( _, RType :: Any ) => Ok ( RType :: Any ) ,
287+ ( RType :: Unknown , _) | ( _, RType :: Unknown ) => Ok ( RType :: Unknown ) ,
288+ ( a, b) => Err ( TypeCheckError :: TypeMismatch {
289+ expected : a. clone ( ) ,
290+ actual : b,
291+ } ) ,
292+ }
293+ }
294+
295+ /// Check arithmetic operations (-, *, /)
296+ fn check_arithmetic ( & self , lhs_type : RType , rhs_type : RType ) -> Result < RType , TypeCheckError > {
297+ match ( lhs_type, rhs_type) {
298+ ( RType :: Integer , RType :: Integer ) => Ok ( RType :: Integer ) ,
299+ ( RType :: Numeric , RType :: Numeric ) => Ok ( RType :: Numeric ) ,
300+ ( RType :: Integer , RType :: Numeric ) | ( RType :: Numeric , RType :: Integer ) => Ok ( RType :: Numeric ) ,
301+ ( RType :: Any , _) | ( _, RType :: Any ) => Ok ( RType :: Any ) ,
302+ ( RType :: Unknown , _) | ( _, RType :: Unknown ) => Ok ( RType :: Unknown ) ,
303+ ( a, b) => Err ( TypeCheckError :: TypeMismatch {
304+ expected : a. clone ( ) ,
305+ actual : b,
306+ } ) ,
307+ }
308+ }
309+
310+ /// Check logical operations (&&, ||)
311+ fn check_logical ( & self , lhs_type : RType , rhs_type : RType ) -> Result < RType , TypeCheckError > {
312+ match ( lhs_type, rhs_type) {
313+ ( RType :: Logical , RType :: Logical ) => Ok ( RType :: Logical ) ,
314+ ( RType :: Any , _) | ( _, RType :: Any ) => Ok ( RType :: Any ) ,
315+ ( RType :: Unknown , _) | ( _, RType :: Unknown ) => Ok ( RType :: Unknown ) ,
316+ ( a, b) => Err ( TypeCheckError :: TypeMismatch {
317+ expected : a. clone ( ) ,
318+ actual : b,
319+ } ) ,
320+ }
321+ }
322+ /// Infer type from R literal values (fallback method)
323+ fn infer_literal_type ( & self , value : & str ) -> RType {
194324 let trimmed = value. trim ( ) ;
195325
196326 // NULL
@@ -287,6 +417,34 @@ y <- "hello" #: character
287417 assert_eq ! ( checker. infer_literal_type( "NULL" ) , RType :: Null ) ;
288418 }
289419
420+ #[ test]
421+ fn test_tree_sitter_expression_inference ( ) {
422+ let mut checker = TypeChecker :: new ( ) ;
423+
424+ // Basic literals - tree-sitter should work for these
425+ assert_eq ! ( checker. infer_expression_type( "42" ) , RType :: Numeric ) ; // In R, 42 is numeric, not integer
426+ assert_eq ! ( checker. infer_expression_type( "42L" ) , RType :: Integer ) ; // 42L is integer
427+ assert_eq ! ( checker. infer_expression_type( "3.14" ) , RType :: Numeric ) ;
428+ assert_eq ! ( checker. infer_expression_type( "\" hello\" " ) , RType :: Character ) ;
429+ assert_eq ! ( checker. infer_expression_type( "TRUE" ) , RType :: Logical ) ;
430+ assert_eq ! ( checker. infer_expression_type( "NULL" ) , RType :: Null ) ;
431+
432+ // Test that tree-sitter parsing is working for more complex expressions
433+ let result = checker. infer_expression_type ( "4 + 4" ) ;
434+ // Should be numeric since 4 + 4 is numeric + numeric
435+ assert ! ( matches!( result, RType :: Numeric | RType :: Unknown ) ) ;
436+
437+ let result = checker. infer_expression_type ( "3.14 + 2.86" ) ;
438+ assert ! ( matches!( result, RType :: Numeric | RType :: Unknown ) ) ;
439+
440+ // Logical operations
441+ let result = checker. infer_expression_type ( "TRUE && FALSE" ) ;
442+ assert ! ( matches!( result, RType :: Logical | RType :: Unknown ) ) ;
443+
444+ let result = checker. infer_expression_type ( "4 > 2" ) ;
445+ assert ! ( matches!( result, RType :: Logical | RType :: Unknown ) ) ;
446+ }
447+
290448 #[ test]
291449 fn test_assignment_checking ( ) {
292450 let mut checker = TypeChecker :: new ( ) ;
@@ -316,4 +474,39 @@ y <- "hello" #: character
316474 assert_eq ! ( checker. extract_variable_name( "result = compute()" ) , Some ( "result" . to_string( ) ) ) ;
317475 assert_eq ! ( checker. extract_variable_name( "invalid syntax" ) , None ) ;
318476 }
477+
478+ #[ test]
479+ fn test_debug_binary_operations ( ) {
480+ let mut checker = TypeChecker :: new ( ) ;
481+
482+ let expressions = vec ! [ "4 + 4" , "3.14 + 2.86" , "TRUE && FALSE" , "4 > 2" ] ;
483+
484+ for expr in expressions {
485+ println ! ( "\n --- Testing expression: '{}' ---" , expr) ;
486+ if let Some ( tree) = parse_expression ( & mut checker. parser , expr) {
487+ let root = tree. root_node ( ) ;
488+ println ! ( "Root node kind: {}" , root. kind( ) ) ;
489+
490+ let mut walker = root. walk ( ) ;
491+ for child in root. children ( & mut walker) {
492+ println ! ( "Child node kind: {}" , child. kind( ) ) ;
493+ print_node_recursive ( child, expr, 1 ) ;
494+ }
495+ }
496+
497+ let result = checker. infer_expression_type ( expr) ;
498+ println ! ( "Result for '{}': {:?}" , expr, result) ;
499+ }
500+ }
501+
502+ fn print_node_recursive ( node : tree_sitter:: Node , source : & str , depth : usize ) {
503+ let indent = " " . repeat ( depth) ;
504+ let text = & source[ node. start_byte ( ) ..node. end_byte ( ) ] ;
505+ println ! ( "{}Node: {} '{}' ({:?})" , indent, node. kind( ) , text, ( node. start_byte( ) , node. end_byte( ) ) ) ;
506+
507+ let mut walker = node. walk ( ) ;
508+ for child in node. children ( & mut walker) {
509+ print_node_recursive ( child, source, depth + 1 ) ;
510+ }
511+ }
319512}
0 commit comments