ai_service.rs (24406B)
1 use async_openai::{ 2 config::OpenAIConfig, 3 types::{ 4 ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, 5 ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage, 6 ChatCompletionRequestUserMessageContent, CreateChatCompletionRequestArgs, 7 }, 8 Client as OpenAIClient, 9 }; 10 11 use serde::{Deserialize, Serialize}; 12 use std::collections::HashMap; 13 use std::sync::{Arc, Mutex}; 14 use tokio::time::{timeout, Duration}; 15 16 #[derive(Debug, Clone, Serialize, Deserialize)] 17 pub struct LanguageDetection { 18 pub language: String, 19 pub confidence: f64, 20 pub iso_code: String, 21 } 22 23 #[derive(Debug, Clone, Serialize, Deserialize)] 24 pub struct SentimentAnalysis { 25 pub sentiment: String, // "positive", "negative", "neutral" 26 pub confidence: f64, 27 pub score: f64, // -1.0 to 1.0 28 pub emotions: Vec<String>, // anger, joy, fear, sadness, etc. 29 } 30 31 #[derive(Debug, Clone, Serialize, Deserialize)] 32 pub struct MessageSummary { 33 pub summary: String, 34 pub key_points: Vec<String>, 35 pub participants: Vec<String>, 36 pub topics: Vec<String>, 37 pub sentiment_overview: String, 38 } 39 40 #[derive(Debug, Clone, Serialize, Deserialize)] 41 pub struct ModerationResult { 42 pub should_moderate: bool, 43 pub severity: u8, // 0-10 44 pub reasons: Vec<String>, 45 pub suggested_action: String, // "none", "warn", "kick", "ban" 46 pub confidence: f64, 47 } 48 49 #[derive(Debug, Clone)] 50 pub struct ChatMessage { 51 pub author: String, 52 pub content: String, 53 pub is_pm: bool, 54 } 55 56 pub struct AIService { 57 client: Option<OpenAIClient<OpenAIConfig>>, 58 message_history: Arc<Mutex<Vec<ChatMessage>>>, 59 language_cache: Arc<Mutex<HashMap<String, LanguageDetection>>>, 60 sentiment_cache: Arc<Mutex<HashMap<String, SentimentAnalysis>>>, 61 max_history: usize, 62 } 63 64 impl AIService { 65 pub fn new() -> Self { 66 let client = std::env::var("OPENAI_API_KEY").ok().map(|api_key| { 67 let config = OpenAIConfig::new().with_api_key(api_key); 68 OpenAIClient::with_config(config) 69 }); 70 71 Self { 72 client, 73 message_history: Arc::new(Mutex::new(Vec::new())), 74 language_cache: Arc::new(Mutex::new(HashMap::new())), 75 sentiment_cache: Arc::new(Mutex::new(HashMap::new())), 76 max_history: 1000, 77 } 78 } 79 80 pub fn is_available(&self) -> bool { 81 self.client.is_some() 82 } 83 84 /// Check if AI can actually be used (not just configured) 85 /// This tests for credit exhaustion and API availability 86 pub async fn is_functional(&self) -> bool { 87 if !self.is_available() { 88 return false; 89 } 90 91 self.test_api_connection().await.unwrap_or(false) 92 } 93 94 /// Test API connection with minimal request to check for credit/quota issues 95 async fn test_api_connection(&self) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> { 96 if let Some(client) = &self.client { 97 let request = CreateChatCompletionRequestArgs::default() 98 .max_tokens(1u16) 99 .model("gpt-3.5-turbo") 100 .messages([ChatCompletionRequestMessage::System( 101 ChatCompletionRequestSystemMessage { 102 content: ChatCompletionRequestSystemMessageContent::Text( 103 "test".to_string(), 104 ), 105 name: None, 106 }, 107 )]) 108 .build()?; 109 110 match timeout(Duration::from_secs(5), client.chat().create(request)).await { 111 Ok(Ok(_)) => Ok(true), 112 Ok(Err(e)) => { 113 // Check for specific credit exhaustion errors 114 let error_msg = e.to_string().to_lowercase(); 115 if error_msg.contains("quota") 116 || error_msg.contains("credit") 117 || error_msg.contains("billing") 118 || error_msg.contains("insufficient") 119 || error_msg.contains("exceeded") 120 { 121 log::warn!("AI service unavailable due to credit/quota issues: {}", e); 122 Ok(false) 123 } else { 124 log::error!("AI service test failed: {}", e); 125 Ok(false) 126 } 127 } 128 Err(_) => { 129 log::warn!("AI service test timed out"); 130 Ok(false) 131 } 132 } 133 } else { 134 Ok(false) 135 } 136 } 137 138 pub fn add_message(&self, message: ChatMessage) { 139 let mut history = self.message_history.lock().unwrap(); 140 history.push(message); 141 142 // Keep only the last max_history messages 143 if history.len() > self.max_history { 144 let excess = history.len() - self.max_history; 145 history.drain(0..excess); 146 } 147 } 148 149 pub fn get_recent_messages(&self, count: usize) -> Vec<ChatMessage> { 150 let history = self.message_history.lock().unwrap(); 151 let start = if history.len() > count { 152 history.len() - count 153 } else { 154 0 155 }; 156 history[start..].to_vec() 157 } 158 159 pub async fn detect_language(&self, text: &str) -> Option<LanguageDetection> { 160 // Check cache first 161 { 162 let cache = self.language_cache.lock().unwrap(); 163 if let Some(detection) = cache.get(text) { 164 return Some(detection.clone()); 165 } 166 } 167 168 let client = self.client.as_ref()?; 169 170 let prompt = format!( 171 "Detect the language of the following text and return only a JSON response in this exact format: 172 {{ 173 \"language\": \"language_name\", 174 \"confidence\": 0.95, 175 \"iso_code\": \"ISO_639-1_code\" 176 }} 177 178 Text to analyze: \"{}\" 179 180 Important: Return ONLY the JSON, no other text or explanation.", 181 text.trim() 182 ); 183 184 let result = timeout(Duration::from_secs(10), async { 185 let request = CreateChatCompletionRequestArgs::default() 186 .model("gpt-3.5-turbo") 187 .messages([ 188 ChatCompletionRequestMessage::System( 189 ChatCompletionRequestSystemMessage { 190 content: ChatCompletionRequestSystemMessageContent::Text( 191 "You are a language detection assistant. Always respond with valid JSON only.".to_string() 192 ), 193 name: None, 194 } 195 ), 196 ChatCompletionRequestMessage::User( 197 ChatCompletionRequestUserMessage { 198 content: ChatCompletionRequestUserMessageContent::Text(prompt), 199 name: None, 200 } 201 ), 202 ]) 203 .max_tokens(100u16) 204 .temperature(0.1) 205 .build()?; 206 207 client.chat().create(request).await 208 }).await; 209 210 match result { 211 Ok(Ok(response)) => { 212 if let Some(choice) = response.choices.first() { 213 if let Some(content) = &choice.message.content { 214 if let Ok(detection) = 215 serde_json::from_str::<LanguageDetection>(content.trim()) 216 { 217 // Cache the result 218 { 219 let mut cache = self.language_cache.lock().unwrap(); 220 cache.insert(text.to_string(), detection.clone()); 221 222 // Limit cache size 223 if cache.len() > 100 { 224 let keys: Vec<String> = 225 cache.keys().take(20).cloned().collect(); 226 for key in keys { 227 cache.remove(&key); 228 } 229 } 230 } 231 return Some(detection); 232 } 233 } 234 } 235 None 236 } 237 _ => None, 238 } 239 } 240 241 pub async fn analyze_sentiment(&self, text: &str) -> Option<SentimentAnalysis> { 242 // Check cache first 243 { 244 let cache = self.sentiment_cache.lock().unwrap(); 245 if let Some(analysis) = cache.get(text) { 246 return Some(analysis.clone()); 247 } 248 } 249 250 let client = self.client.as_ref()?; 251 252 let prompt = format!( 253 "Analyze the sentiment and emotions of the following text and return only a JSON response in this exact format: 254 {{ 255 \"sentiment\": \"positive|negative|neutral\", 256 \"confidence\": 0.85, 257 \"score\": 0.3, 258 \"emotions\": [\"joy\", \"excitement\"] 259 }} 260 261 Text to analyze: \"{}\" 262 263 Score should be between -1.0 (very negative) and 1.0 (very positive). 264 Emotions can include: joy, anger, fear, sadness, surprise, disgust, trust, anticipation. 265 Return ONLY the JSON, no other text.", 266 text.trim() 267 ); 268 269 let result = timeout(Duration::from_secs(10), async { 270 let request = CreateChatCompletionRequestArgs::default() 271 .model("gpt-3.5-turbo") 272 .messages([ 273 ChatCompletionRequestMessage::System( 274 ChatCompletionRequestSystemMessage { 275 content: ChatCompletionRequestSystemMessageContent::Text( 276 "You are a sentiment analysis assistant. Always respond with valid JSON only.".to_string() 277 ), 278 name: None, 279 } 280 ), 281 ChatCompletionRequestMessage::User( 282 ChatCompletionRequestUserMessage { 283 content: ChatCompletionRequestUserMessageContent::Text(prompt), 284 name: None, 285 } 286 ), 287 ]) 288 .max_tokens(150u16) 289 .temperature(0.1) 290 .build()?; 291 292 client.chat().create(request).await 293 }).await; 294 295 match result { 296 Ok(Ok(response)) => { 297 if let Some(choice) = response.choices.first() { 298 if let Some(content) = &choice.message.content { 299 if let Ok(analysis) = 300 serde_json::from_str::<SentimentAnalysis>(content.trim()) 301 { 302 // Cache the result 303 { 304 let mut cache = self.sentiment_cache.lock().unwrap(); 305 cache.insert(text.to_string(), analysis.clone()); 306 307 // Limit cache size 308 if cache.len() > 100 { 309 let keys: Vec<String> = 310 cache.keys().take(20).cloned().collect(); 311 for key in keys { 312 cache.remove(&key); 313 } 314 } 315 } 316 return Some(analysis); 317 } 318 } 319 } 320 None 321 } 322 _ => None, 323 } 324 } 325 326 pub async fn summarize_chat(&self, message_count: Option<usize>) -> Option<MessageSummary> { 327 let client = self.client.as_ref()?; 328 329 let count = message_count.unwrap_or(50); 330 let messages = self.get_recent_messages(count); 331 332 if messages.is_empty() { 333 return None; 334 } 335 336 let chat_text = messages 337 .iter() 338 .map(|msg| { 339 if msg.is_pm { 340 format!("[PM] {}: {}", msg.author, msg.content) 341 } else { 342 format!("{}: {}", msg.author, msg.content) 343 } 344 }) 345 .collect::<Vec<_>>() 346 .join("\n"); 347 348 let prompt = format!( 349 "Analyze and summarize the following chat conversation. Return only a JSON response in this exact format: 350 {{ 351 \"summary\": \"Brief summary of the conversation\", 352 \"key_points\": [\"Point 1\", \"Point 2\"], 353 \"participants\": [\"user1\", \"user2\"], 354 \"topics\": [\"topic1\", \"topic2\"], 355 \"sentiment_overview\": \"Overall mood description\" 356 }} 357 358 Chat messages: 359 {} 360 361 Return ONLY the JSON, no other text.", 362 chat_text 363 ); 364 365 let result = timeout(Duration::from_secs(15), async { 366 let request = CreateChatCompletionRequestArgs::default() 367 .model("gpt-3.5-turbo") 368 .messages([ 369 ChatCompletionRequestMessage::System( 370 ChatCompletionRequestSystemMessage { 371 content: ChatCompletionRequestSystemMessageContent::Text( 372 "You are a chat summarization assistant. Always respond with valid JSON only.".to_string() 373 ), 374 name: None, 375 } 376 ), 377 ChatCompletionRequestMessage::User( 378 ChatCompletionRequestUserMessage { 379 content: ChatCompletionRequestUserMessageContent::Text(prompt), 380 name: None, 381 } 382 ), 383 ]) 384 .max_tokens(400u16) 385 .temperature(0.3) 386 .build()?; 387 388 client.chat().create(request).await 389 }).await; 390 391 match result { 392 Ok(Ok(response)) => { 393 if let Some(choice) = response.choices.first() { 394 if let Some(content) = &choice.message.content { 395 if let Ok(summary) = serde_json::from_str::<MessageSummary>(content.trim()) 396 { 397 return Some(summary); 398 } 399 } 400 } 401 None 402 } 403 _ => None, 404 } 405 } 406 407 pub async fn advanced_moderation(&self, text: &str, context: &str) -> Option<ModerationResult> { 408 let client = self.client.as_ref()?; 409 410 let prompt = format!( 411 "Analyze the following message for harmful content, considering the chat context. Return only a JSON response in this exact format: 412 {{ 413 \"should_moderate\": false, 414 \"severity\": 3, 415 \"reasons\": [\"reason1\", \"reason2\"], 416 \"suggested_action\": \"warn\", 417 \"confidence\": 0.85 418 }} 419 420 Message to analyze: \"{}\" 421 Chat context: \"{}\" 422 423 Severity scale: 0 (harmless) to 10 (extremely harmful) 424 Suggested actions: \"none\", \"warn\", \"kick\", \"ban\" 425 Consider: harassment, hate speech, spam, threats, inappropriate content, but also context and intent. 426 Return ONLY the JSON, no other text.", 427 text.trim(), 428 context.trim() 429 ); 430 431 let result = timeout(Duration::from_secs(12), async { 432 let request = CreateChatCompletionRequestArgs::default() 433 .model("gpt-3.5-turbo") 434 .messages([ 435 ChatCompletionRequestMessage::System( 436 ChatCompletionRequestSystemMessage { 437 content: ChatCompletionRequestSystemMessageContent::Text( 438 "You are a content moderation assistant for an anonymous chat. Be balanced - not too strict but protect users from genuine harm. Always respond with valid JSON only.".to_string() 439 ), 440 name: None, 441 } 442 ), 443 ChatCompletionRequestMessage::User( 444 ChatCompletionRequestUserMessage { 445 content: ChatCompletionRequestUserMessageContent::Text(prompt), 446 name: None, 447 } 448 ), 449 ]) 450 .max_tokens(200u16) 451 .temperature(0.2) 452 .build()?; 453 454 client.chat().create(request).await 455 }).await; 456 457 match result { 458 Ok(Ok(response)) => { 459 if let Some(choice) = response.choices.first() { 460 if let Some(content) = &choice.message.content { 461 if let Ok(moderation) = 462 serde_json::from_str::<ModerationResult>(content.trim()) 463 { 464 return Some(moderation); 465 } 466 } 467 } 468 None 469 } 470 _ => None, 471 } 472 } 473 474 pub async fn translate_text(&self, text: &str, target_language: &str) -> Option<String> { 475 let client = self.client.as_ref()?; 476 477 let prompt = format!( 478 "Translate the following text to {} and return ONLY the translated text, no explanations or quotes: 479 480 Text to translate: \"{}\"", 481 target_language, text.trim() 482 ); 483 484 let result = timeout(Duration::from_secs(10), async { 485 let request = CreateChatCompletionRequestArgs::default() 486 .model("gpt-3.5-turbo") 487 .messages([ 488 ChatCompletionRequestMessage::System( 489 ChatCompletionRequestSystemMessage { 490 content: ChatCompletionRequestSystemMessageContent::Text( 491 "You are a translation assistant. Always return only the translated text, nothing else.".to_string() 492 ), 493 name: None, 494 } 495 ), 496 ChatCompletionRequestMessage::User( 497 ChatCompletionRequestUserMessage { 498 content: ChatCompletionRequestUserMessageContent::Text(prompt), 499 name: None, 500 } 501 ), 502 ]) 503 .max_tokens(300u16) 504 .temperature(0.1) 505 .build()?; 506 507 client.chat().create(request).await 508 }).await; 509 510 match result { 511 Ok(Ok(response)) => { 512 if let Some(choice) = response.choices.first() { 513 choice.message.content.clone() 514 } else { 515 None 516 } 517 } 518 _ => None, 519 } 520 } 521 522 pub fn get_chat_atmosphere(&self) -> String { 523 let messages = self.get_recent_messages(20); 524 if messages.is_empty() { 525 return "😐 Quiet".to_string(); 526 } 527 528 let total_msgs = messages.len(); 529 let unique_users: std::collections::HashSet<_> = 530 messages.iter().map(|m| &m.author).collect(); 531 let user_count = unique_users.len(); 532 533 // Simple heuristic for activity level 534 let activity = if total_msgs > 15 { 535 "Very Active" 536 } else if total_msgs > 8 { 537 "Active" 538 } else if total_msgs > 3 { 539 "Moderate" 540 } else { 541 "Quiet" 542 }; 543 544 let diversity = if user_count > 5 { 545 "Diverse" 546 } else if user_count > 2 { 547 "Social" 548 } else { 549 "Focused" 550 }; 551 552 format!( 553 "🌊 {} & {} ({} msgs, {} users)", 554 activity, diversity, total_msgs, user_count 555 ) 556 } 557 558 pub fn get_stats(&self) -> HashMap<String, String> { 559 let mut stats = HashMap::new(); 560 let history = self.message_history.lock().unwrap(); 561 let lang_cache = self.language_cache.lock().unwrap(); 562 let sentiment_cache = self.sentiment_cache.lock().unwrap(); 563 564 stats.insert("available".to_string(), self.is_available().to_string()); 565 stats.insert("message_history".to_string(), history.len().to_string()); 566 stats.insert("language_cache".to_string(), lang_cache.len().to_string()); 567 stats.insert( 568 "sentiment_cache".to_string(), 569 sentiment_cache.len().to_string(), 570 ); 571 stats.insert("max_history".to_string(), self.max_history.to_string()); 572 573 stats 574 } 575 } 576 577 // Helper function for fallback language detection using simple heuristics 578 pub fn fallback_language_detection(text: &str) -> LanguageDetection { 579 let text_lower = text.to_lowercase(); 580 581 // Simple heuristic based on common words 582 let english_indicators = ["the", "and", "is", "are", "you", "that", "with", "for"]; 583 let spanish_indicators = ["el", "la", "es", "con", "por", "que", "una", "para"]; 584 let french_indicators = ["le", "la", "et", "est", "avec", "pour", "une", "dans"]; 585 let german_indicators = ["der", "die", "und", "ist", "mit", "für", "eine", "das"]; 586 587 let english_count = english_indicators 588 .iter() 589 .filter(|&&word| text_lower.contains(word)) 590 .count(); 591 let spanish_count = spanish_indicators 592 .iter() 593 .filter(|&&word| text_lower.contains(word)) 594 .count(); 595 let french_count = french_indicators 596 .iter() 597 .filter(|&&word| text_lower.contains(word)) 598 .count(); 599 let german_count = german_indicators 600 .iter() 601 .filter(|&&word| text_lower.contains(word)) 602 .count(); 603 604 let max_count = *[english_count, spanish_count, french_count, german_count] 605 .iter() 606 .max() 607 .unwrap_or(&0); 608 609 if max_count == 0 { 610 return LanguageDetection { 611 language: "Unknown".to_string(), 612 confidence: 0.1, 613 iso_code: "??".to_string(), 614 }; 615 } 616 617 let confidence = (max_count as f64 / text.split_whitespace().count() as f64).min(0.9); 618 619 if english_count == max_count { 620 LanguageDetection { 621 language: "English".to_string(), 622 confidence, 623 iso_code: "en".to_string(), 624 } 625 } else if spanish_count == max_count { 626 LanguageDetection { 627 language: "Spanish".to_string(), 628 confidence, 629 iso_code: "es".to_string(), 630 } 631 } else if french_count == max_count { 632 LanguageDetection { 633 language: "French".to_string(), 634 confidence, 635 iso_code: "fr".to_string(), 636 } 637 } else if german_count == max_count { 638 LanguageDetection { 639 language: "German".to_string(), 640 confidence, 641 iso_code: "de".to_string(), 642 } 643 } else { 644 LanguageDetection { 645 language: "Unknown".to_string(), 646 confidence: 0.1, 647 iso_code: "??".to_string(), 648 } 649 } 650 } 651 652 // Helper function for basic sentiment analysis without AI 653 pub fn fallback_sentiment_analysis(text: &str) -> SentimentAnalysis { 654 let text_lower = text.to_lowercase(); 655 656 let positive_words = [ 657 "good", 658 "great", 659 "awesome", 660 "amazing", 661 "happy", 662 "love", 663 "excellent", 664 "perfect", 665 "wonderful", 666 ]; 667 let negative_words = [ 668 "bad", 669 "terrible", 670 "awful", 671 "hate", 672 "sad", 673 "angry", 674 "horrible", 675 "disgusting", 676 "annoying", 677 ]; 678 679 let positive_count = positive_words 680 .iter() 681 .filter(|&&word| text_lower.contains(word)) 682 .count() as f64; 683 let negative_count = negative_words 684 .iter() 685 .filter(|&&word| text_lower.contains(word)) 686 .count() as f64; 687 688 let total_words = text.split_whitespace().count() as f64; 689 let score = (positive_count - negative_count) / total_words.max(1.0); 690 691 let (sentiment, emotions) = if score > 0.1 { 692 ("positive", vec!["joy".to_string()]) 693 } else if score < -0.1 { 694 ("negative", vec!["anger".to_string()]) 695 } else { 696 ("neutral", vec![]) 697 }; 698 699 SentimentAnalysis { 700 sentiment: sentiment.to_string(), 701 confidence: 0.6, 702 score: score.clamp(-1.0, 1.0), 703 emotions, 704 } 705 }