Files
vibe-invest/src/strategy.rs
2026-02-11 18:00:12 +00:00

141 lines
4.5 KiB
Rust

//! Common trading strategy logic used by both the live bot and the backtester.
use std::collections::HashMap;
use crate::config::{
get_sector, IndicatorParams, Timeframe, ATR_STOP_MULTIPLIER,
ATR_TRAIL_ACTIVATION_MULTIPLIER, ATR_TRAIL_MULTIPLIER, MAX_LOSS_PCT, MAX_POSITION_SIZE,
MIN_ATR_PCT, RISK_PER_TRADE, STOP_LOSS_PCT, TIME_EXIT_BARS,
TRAILING_STOP_ACTIVATION, TRAILING_STOP_DISTANCE,
};
use crate::types::{Signal, TradeSignal};
/// Contains the core trading strategy logic shared by the live bot and the backtester.
///
/// The three per-symbol maps are keyed by symbol and are consulted by the exit checks
/// implemented on this type. Nothing in this type inserts into them; callers are
/// expected to populate them when a position is opened.
/// NOTE(review): confirm where entries are added and removed.
pub struct Strategy {
    /// Indicator parameters, obtained from `Timeframe::params()` at construction.
    pub params: IndicatorParams,
    /// Highest price observed for each tracked position.
    pub high_water_marks: HashMap<String, f64>,
    /// ATR recorded at entry for each position, in price units.
    pub entry_atrs: HashMap<String, f64>,
    /// Entry price for each position.
    pub entry_prices: HashMap<String, f64>,
}
impl Strategy {
    /// Creates a strategy using the indicator parameters of `timeframe`, with empty
    /// per-symbol tracking maps.
    pub fn new(timeframe: Timeframe) -> Self {
        Self {
            params: timeframe.params(),
            high_water_marks: HashMap::new(),
            entry_atrs: HashMap::new(),
            entry_prices: HashMap::new(),
        }
    }

    /// Volatility-adjusted position sizing using ATR.
    ///
    /// Returns the number of whole shares to buy at `price`:
    ///
    /// * When `signal.atr_pct > MIN_ATR_PCT`, the position is sized so that a stop at
    ///   `atr_pct * ATR_STOP_MULTIPLIER` risks `RISK_PER_TRADE` of `portfolio_value`.
    ///   The result is scaled by `0.7 + 0.3 * confidence`, which gives [0.7, 1.0] for a
    ///   confidence in [0, 1] (assumed range; not enforced here). It is then capped at
    ///   `MAX_POSITION_SIZE` of the portfolio.
    /// * Otherwise a flat `MAX_POSITION_SIZE` fraction of the portfolio is used.
    ///
    /// The position value never exceeds `available_cash`. Returns 0 when there is no
    /// cash, or when `price` is not a finite, strictly positive number.
    pub fn calculate_position_size(
        &self,
        price: f64,
        portfolio_value: f64,
        available_cash: f64,
        signal: &TradeSignal,
    ) -> u64 {
        // A zero price would divide to +inf below, which `as u64` saturates to
        // u64::MAX shares. Reject it (and any other non-positive/non-finite price).
        if available_cash <= 0.0 || !price.is_finite() || price <= 0.0 {
            return 0;
        }

        let max_position_value = portfolio_value * MAX_POSITION_SIZE;
        let position_value = if signal.atr_pct > MIN_ATR_PCT {
            let atr_stop_pct = signal.atr_pct * ATR_STOP_MULTIPLIER;
            let risk_amount = portfolio_value * RISK_PER_TRADE;
            let vol_adjusted = risk_amount / atr_stop_pct;
            // Scale by confidence so low-confidence signals take smaller positions.
            let confidence_scale = 0.7 + 0.3 * signal.confidence;
            (vol_adjusted * confidence_scale).min(max_position_value)
        } else {
            max_position_value
        };

        let position_value = position_value.min(available_cash);
        // `as` saturates: negative or NaN share counts become 0.
        (position_value / price).floor() as u64
    }

    /// Checks whether an open position in `symbol` should be exited.
    ///
    /// First raises the symbol's high-water mark to `current_price` if higher. This
    /// happens only when a high-water mark is already tracked. The method then
    /// evaluates, in priority order:
    ///
    /// 1. Hard max-loss cap: loss of at least `MAX_LOSS_PCT` gives `StrongSell`.
    /// 2. Stop loss. If an entry ATR > 0 is recorded, the stop is hit at or below
    ///    `entry_price - ATR_STOP_MULTIPLIER * entry_atr`. Otherwise it is hit on a loss
    ///    of at least `STOP_LOSS_PCT`. Either case gives `StrongSell`.
    /// 3. Time exit: after `TIME_EXIT_BARS` bars, if the gain is still below the
    ///    trailing-stop activation threshold, the result is `Sell`.
    /// 4. Trailing stop: once the gain reaches the activation threshold, the result is
    ///    `Sell` when price is at or below `high_water - trail_distance`.
    ///
    /// Returns `None` in three cases: no entry price is recorded for `symbol`, the
    /// recorded entry price is not finite and strictly positive, or no exit condition
    /// fires.
    pub fn check_stop_loss_take_profit(
        &mut self,
        symbol: &str,
        current_price: f64,
        bars_held: usize,
    ) -> Option<Signal> {
        let entry_price = *self.entry_prices.get(symbol)?;
        let entry_atr = self.entry_atrs.get(symbol).copied().unwrap_or(0.0);
        let has_atr = entry_atr > 0.0;

        // Update high water mark.
        if let Some(hwm) = self.high_water_marks.get_mut(symbol) {
            if current_price > *hwm {
                *hwm = current_price;
            }
        }

        // Every percentage below divides by entry_price. A zero or invalid entry
        // would yield inf/NaN and produce arbitrary exit signals.
        if !entry_price.is_finite() || entry_price <= 0.0 {
            return None;
        }
        let pnl_pct = (current_price - entry_price) / entry_price;

        // Hard max-loss cap.
        if pnl_pct <= -MAX_LOSS_PCT {
            return Some(Signal::StrongSell);
        }

        // Stop loss: ATR-based when an entry ATR is known, fixed percentage otherwise.
        let stopped_out = if has_atr {
            current_price <= entry_price - ATR_STOP_MULTIPLIER * entry_atr
        } else {
            pnl_pct <= -STOP_LOSS_PCT
        };
        if stopped_out {
            return Some(Signal::StrongSell);
        }

        // One threshold serves both the time exit and the trailing-stop arming.
        let activation_gain = Self::trail_activation_pct(entry_price, entry_atr);

        // Time-based exit: the position has been held long enough without reaching
        // the gain that would arm the trailing stop.
        if bars_held >= TIME_EXIT_BARS && pnl_pct < activation_gain {
            return Some(Signal::Sell);
        }

        // Trailing stop, armed once the gain reaches the activation threshold.
        if pnl_pct >= activation_gain {
            if let Some(&high_water) = self.high_water_marks.get(symbol) {
                let trail_distance = if has_atr {
                    ATR_TRAIL_MULTIPLIER * entry_atr
                } else {
                    high_water * TRAILING_STOP_DISTANCE
                };
                if current_price <= high_water - trail_distance {
                    return Some(Signal::Sell);
                }
            }
        }

        None
    }

    /// Returns the fractional gain over `entry_price` at which the trailing stop arms.
    /// A time exit fires below this gain.
    ///
    /// The value is ATR-scaled (`ATR_TRAIL_ACTIVATION_MULTIPLIER * entry_atr / entry_price`)
    /// when `entry_atr > 0`, and the fixed `TRAILING_STOP_ACTIVATION` otherwise.
    fn trail_activation_pct(entry_price: f64, entry_atr: f64) -> f64 {
        if entry_atr > 0.0 {
            (ATR_TRAIL_ACTIVATION_MULTIPLIER * entry_atr) / entry_price
        } else {
            TRAILING_STOP_ACTIVATION
        }
    }

    /// Counts how many of `positions` belong to `sector`, as classified by `get_sector`.
    pub fn sector_position_count<'a, I>(&self, sector: &str, positions: I) -> usize
    where
        I: IntoIterator<Item = &'a String>,
    {
        positions
            .into_iter()
            .filter(|sym| get_sector(sym) == sector)
            .count()
    }
}