//! Common trading strategy logic used by both the live bot and the backtester. use std::collections::HashMap; use crate::config::{ get_sector, IndicatorParams, Timeframe, ATR_STOP_MULTIPLIER, ATR_TRAIL_ACTIVATION_MULTIPLIER, ATR_TRAIL_MULTIPLIER, MAX_LOSS_PCT, MAX_POSITION_SIZE, MIN_ATR_PCT, RISK_PER_TRADE, STOP_LOSS_PCT, TIME_EXIT_BARS, TRAILING_STOP_ACTIVATION, TRAILING_STOP_DISTANCE, }; use crate::types::{Signal, TradeSignal}; /// Contains the core trading strategy logic. pub struct Strategy { pub params: IndicatorParams, pub high_water_marks: HashMap, pub entry_atrs: HashMap, pub entry_prices: HashMap, } impl Strategy { pub fn new(timeframe: Timeframe) -> Self { Self { params: timeframe.params(), high_water_marks: HashMap::new(), entry_atrs: HashMap::new(), entry_prices: HashMap::new(), } } /// Volatility-adjusted position sizing using ATR. pub fn calculate_position_size( &self, price: f64, portfolio_value: f64, available_cash: f64, signal: &TradeSignal, ) -> u64 { if available_cash <= 0.0 { return 0; } let position_value = if signal.atr_pct > MIN_ATR_PCT { let atr_stop_pct = signal.atr_pct * ATR_STOP_MULTIPLIER; let risk_amount = portfolio_value * RISK_PER_TRADE; let vol_adjusted = risk_amount / atr_stop_pct; // Scale by confidence let confidence_scale = 0.7 + 0.3 * signal.confidence; let sized = vol_adjusted * confidence_scale; sized.min(portfolio_value * MAX_POSITION_SIZE) } else { portfolio_value * MAX_POSITION_SIZE }; let position_value = position_value.min(available_cash); (position_value / price).floor() as u64 } /// Check if stop-loss, trailing stop, or time exit should trigger. pub fn check_stop_loss_take_profit( &mut self, symbol: &str, current_price: f64, bars_held: usize, ) -> Option { let entry_price = match self.entry_prices.get(symbol) { Some(&p) => p, None => return None, }; let pnl_pct = (current_price - entry_price) / entry_price; let entry_atr = self.entry_atrs.get(symbol).copied().unwrap_or(0.0); // Update high water mark if let Some(hwm) = self.high_water_marks.get_mut(symbol) { if current_price > *hwm { *hwm = current_price; } } // Hard max-loss cap if pnl_pct <= -MAX_LOSS_PCT { return Some(Signal::StrongSell); } // ATR-based stop loss if entry_atr > 0.0 { let atr_stop_price = entry_price - ATR_STOP_MULTIPLIER * entry_atr; if current_price <= atr_stop_price { return Some(Signal::StrongSell); } } else if pnl_pct <= -STOP_LOSS_PCT { return Some(Signal::StrongSell); } // Time-based exit if bars_held >= TIME_EXIT_BARS { let activation = if entry_atr > 0.0 { (ATR_TRAIL_ACTIVATION_MULTIPLIER * entry_atr) / entry_price } else { TRAILING_STOP_ACTIVATION }; if pnl_pct < activation { return Some(Signal::Sell); } } // ATR-based trailing stop let activation_gain = if entry_atr > 0.0 { (ATR_TRAIL_ACTIVATION_MULTIPLIER * entry_atr) / entry_price } else { TRAILING_STOP_ACTIVATION }; if pnl_pct >= activation_gain { if let Some(&high_water) = self.high_water_marks.get(symbol) { let trail_distance = if entry_atr > 0.0 { ATR_TRAIL_MULTIPLIER * entry_atr } else { high_water * TRAILING_STOP_DISTANCE }; let trailing_stop_price = high_water - trail_distance; if current_price <= trailing_stop_price { return Some(Signal::Sell); } } } None } /// Count positions in a given sector. pub fn sector_position_count<'a, I>(&self, sector: &str, positions: I) -> usize where I: IntoIterator, { positions .into_iter() .filter(|sym| get_sector(sym) == sector) .count() } }