first comit

This commit is contained in:
zastian-dev
2026-02-09 19:20:47 +00:00
commit 79625743bd
24 changed files with 8726 additions and 0 deletions

699
src/backtester.rs Normal file
View File

@@ -0,0 +1,699 @@
//! Backtesting engine for the trading strategy.
use anyhow::{Context, Result};
use chrono::{DateTime, Duration, Utc};
use std::collections::{BTreeMap, HashMap, HashSet};
use crate::alpaca::{fetch_backtest_data, AlpacaClient};
use crate::config::{
get_all_symbols, IndicatorParams, Timeframe, HOURS_PER_DAY, MAX_POSITION_SIZE,
MIN_CASH_RESERVE, STOP_LOSS_PCT, TAKE_PROFIT_PCT, TOP_MOMENTUM_COUNT,
TRADING_DAYS_PER_YEAR, TRAILING_STOP_ACTIVATION, TRAILING_STOP_DISTANCE,
};
use crate::indicators::{calculate_all_indicators, generate_signal};
use crate::types::{
BacktestPosition, BacktestResult, EquityPoint, IndicatorRow, Signal, Trade,
};
/// Backtesting engine for the trading strategy.
pub struct Backtester {
initial_capital: f64,
cash: f64,
positions: HashMap<String, BacktestPosition>,
trades: Vec<Trade>,
equity_history: Vec<EquityPoint>,
entry_prices: HashMap<String, f64>,
high_water_marks: HashMap<String, f64>,
params: IndicatorParams,
timeframe: Timeframe,
}
impl Backtester {
/// Create a new backtester.
pub fn new(initial_capital: f64, timeframe: Timeframe) -> Self {
Self {
initial_capital,
cash: initial_capital,
positions: HashMap::new(),
trades: Vec::new(),
equity_history: Vec::new(),
entry_prices: HashMap::new(),
high_water_marks: HashMap::new(),
params: timeframe.params(),
timeframe,
}
}
/// Calculate current portfolio value.
fn get_portfolio_value(&self, prices: &HashMap<String, f64>) -> f64 {
let positions_value: f64 = self
.positions
.iter()
.map(|(symbol, pos)| pos.shares * prices.get(symbol).unwrap_or(&pos.entry_price))
.sum();
self.cash + positions_value
}
/// Calculate position size based on risk management.
fn calculate_position_size(&self, price: f64, portfolio_value: f64) -> u64 {
let max_allocation = portfolio_value * MAX_POSITION_SIZE;
let available_cash = self.cash - (portfolio_value * MIN_CASH_RESERVE);
if available_cash <= 0.0 {
return 0;
}
let position_value = max_allocation.min(available_cash);
(position_value / price).floor() as u64
}
/// Execute a simulated buy order.
fn execute_buy(
&mut self,
symbol: &str,
price: f64,
timestamp: DateTime<Utc>,
portfolio_value: f64,
) -> bool {
if self.positions.contains_key(symbol) {
return false;
}
let shares = self.calculate_position_size(price, portfolio_value);
if shares == 0 {
return false;
}
let cost = shares as f64 * price;
if cost > self.cash {
return false;
}
self.cash -= cost;
self.positions.insert(
symbol.to_string(),
BacktestPosition {
symbol: symbol.to_string(),
shares: shares as f64,
entry_price: price,
entry_time: timestamp,
},
);
self.entry_prices.insert(symbol.to_string(), price);
self.high_water_marks.insert(symbol.to_string(), price);
self.trades.push(Trade {
symbol: symbol.to_string(),
side: "BUY".to_string(),
shares: shares as f64,
price,
timestamp,
pnl: 0.0,
pnl_pct: 0.0,
});
true
}
/// Execute a simulated sell order.
fn execute_sell(&mut self, symbol: &str, price: f64, timestamp: DateTime<Utc>) -> bool {
let position = match self.positions.remove(symbol) {
Some(p) => p,
None => return false,
};
let proceeds = position.shares * price;
self.cash += proceeds;
let pnl = proceeds - (position.shares * position.entry_price);
let pnl_pct = (price - position.entry_price) / position.entry_price;
self.trades.push(Trade {
symbol: symbol.to_string(),
side: "SELL".to_string(),
shares: position.shares,
price,
timestamp,
pnl,
pnl_pct,
});
self.entry_prices.remove(symbol);
self.high_water_marks.remove(symbol);
true
}
/// Check if stop-loss, take-profit, or trailing stop should trigger.
fn check_stop_loss_take_profit(&mut self, symbol: &str, current_price: f64) -> Option<Signal> {
let entry_price = match self.entry_prices.get(symbol) {
Some(&p) => p,
None => return None,
};
let pnl_pct = (current_price - entry_price) / entry_price;
// Update high water mark
if let Some(hwm) = self.high_water_marks.get_mut(symbol) {
if current_price > *hwm {
*hwm = current_price;
}
}
// Fixed stop loss
if pnl_pct <= -STOP_LOSS_PCT {
return Some(Signal::StrongSell);
}
// Take profit
if pnl_pct >= TAKE_PROFIT_PCT {
return Some(Signal::Sell);
}
// Trailing stop (only after activation threshold)
if pnl_pct >= TRAILING_STOP_ACTIVATION {
if let Some(&high_water) = self.high_water_marks.get(symbol) {
let trailing_stop_price = high_water * (1.0 - TRAILING_STOP_DISTANCE);
if current_price <= trailing_stop_price {
return Some(Signal::Sell);
}
}
}
None
}
/// Run the backtest simulation.
pub async fn run(&mut self, client: &AlpacaClient, years: f64) -> Result<BacktestResult> {
let symbols = get_all_symbols();
// Calculate warmup period
let warmup_period = self.params.min_bars() + 10;
let warmup_calendar_days = if self.timeframe == Timeframe::Hourly {
(warmup_period as f64 / HOURS_PER_DAY as f64 * 1.5) as i64
} else {
(warmup_period as f64 * 1.5) as i64
};
tracing::info!("{}", "=".repeat(70));
tracing::info!("STARTING BACKTEST");
tracing::info!("Initial Capital: ${:.2}", self.initial_capital);
tracing::info!("Period: {:.2} years ({:.1} months)", years, years * 12.0);
tracing::info!("Timeframe: {:?} bars", self.timeframe);
if self.timeframe == Timeframe::Hourly {
tracing::info!(
"Parameters scaled {}x (e.g., RSI: {}, EMA_TREND: {})",
HOURS_PER_DAY,
self.params.rsi_period,
self.params.ema_trend
);
}
tracing::info!("{}", "=".repeat(70));
// Fetch historical data
let raw_data = fetch_backtest_data(
client,
&symbols.iter().map(|s| *s).collect::<Vec<_>>(),
years,
self.timeframe,
warmup_calendar_days,
)
.await?;
if raw_data.is_empty() {
anyhow::bail!("No historical data available for backtesting");
}
// Calculate indicators for all symbols
let mut data: HashMap<String, Vec<IndicatorRow>> = HashMap::new();
for (symbol, bars) in &raw_data {
let min_bars = self.params.min_bars();
if bars.len() < min_bars {
tracing::warn!(
"{}: Only {} bars, need {}. Skipping.",
symbol,
bars.len(),
min_bars
);
continue;
}
let indicators = calculate_all_indicators(bars, &self.params);
data.insert(symbol.clone(), indicators);
}
// Get common date range
let mut all_dates: BTreeMap<DateTime<Utc>, HashSet<String>> = BTreeMap::new();
for (symbol, rows) in &data {
for row in rows {
all_dates
.entry(row.timestamp)
.or_insert_with(HashSet::new)
.insert(symbol.clone());
}
}
let all_dates: Vec<DateTime<Utc>> = all_dates.keys().copied().collect();
// Calculate trading start date
let end_date = Utc::now();
let trading_start_date = end_date - Duration::days((years * 365.0) as i64);
// Filter to only trade on requested period
let trading_dates: Vec<DateTime<Utc>> = all_dates
.iter()
.filter(|&&d| d >= trading_start_date)
.copied()
.collect();
// Ensure we have enough warmup
let trading_dates = if !trading_dates.is_empty() {
let first_trading_idx = all_dates
.iter()
.position(|&d| d == trading_dates[0])
.unwrap_or(0);
if first_trading_idx < warmup_period {
trading_dates
.into_iter()
.skip(warmup_period - first_trading_idx)
.collect()
} else {
trading_dates
}
} else {
trading_dates
};
if trading_dates.is_empty() {
anyhow::bail!(
"No trading days available after warmup. \
Try a longer backtest period (at least 4 months recommended)."
);
}
tracing::info!(
"\nSimulating {} trading days (after {}-day warmup)...",
trading_dates.len(),
warmup_period
);
// Build index lookup for each symbol's data
let mut symbol_date_index: HashMap<String, HashMap<DateTime<Utc>, usize>> = HashMap::new();
for (symbol, rows) in &data {
let mut idx_map = HashMap::new();
for (i, row) in rows.iter().enumerate() {
idx_map.insert(row.timestamp, i);
}
symbol_date_index.insert(symbol.clone(), idx_map);
}
// Main simulation loop
for (day_num, current_date) in trading_dates.iter().enumerate() {
// Get current prices and momentum for all symbols
let mut current_prices: HashMap<String, f64> = HashMap::new();
let mut momentum_scores: HashMap<String, f64> = HashMap::new();
for (symbol, rows) in &data {
if let Some(&idx) = symbol_date_index.get(symbol).and_then(|m| m.get(current_date)) {
let row = &rows[idx];
current_prices.insert(symbol.clone(), row.close);
if !row.momentum.is_nan() {
momentum_scores.insert(symbol.clone(), row.momentum);
}
}
}
let portfolio_value = self.get_portfolio_value(&current_prices);
// Momentum ranking: sort symbols by momentum
let mut ranked_symbols: Vec<String> = momentum_scores.keys().cloned().collect();
ranked_symbols.sort_by(|a, b| {
let ma = momentum_scores.get(a).unwrap_or(&0.0);
let mb = momentum_scores.get(b).unwrap_or(&0.0);
mb.partial_cmp(ma).unwrap_or(std::cmp::Ordering::Equal)
});
let top_momentum_symbols: HashSet<String> =
ranked_symbols.iter().take(TOP_MOMENTUM_COUNT).cloned().collect();
// Process sells first (for all symbols with positions)
let position_symbols: Vec<String> = self.positions.keys().cloned().collect();
for symbol in position_symbols {
let rows = match data.get(&symbol) {
Some(r) => r,
None => continue,
};
let idx = match symbol_date_index.get(&symbol).and_then(|m| m.get(current_date)) {
Some(&i) => i,
None => continue,
};
if idx < 1 {
continue;
}
let current_row = &rows[idx];
let previous_row = &rows[idx - 1];
if current_row.rsi.is_nan() || current_row.macd.is_nan() {
continue;
}
let mut signal = generate_signal(&symbol, current_row, previous_row);
// Check stop-loss/take-profit/trailing stop
if let Some(sl_tp) = self.check_stop_loss_take_profit(&symbol, signal.current_price)
{
signal.signal = sl_tp;
}
// Execute sells
if signal.signal.is_sell() {
self.execute_sell(&symbol, signal.current_price, *current_date);
}
}
// Process buys (only for top momentum stocks)
for symbol in &ranked_symbols {
let rows = match data.get(symbol) {
Some(r) => r,
None => continue,
};
// Only buy top momentum stocks
if !top_momentum_symbols.contains(symbol) {
continue;
}
let idx = match symbol_date_index.get(symbol).and_then(|m| m.get(current_date)) {
Some(&i) => i,
None => continue,
};
if idx < 1 {
continue;
}
let current_row = &rows[idx];
let previous_row = &rows[idx - 1];
if current_row.rsi.is_nan() || current_row.macd.is_nan() {
continue;
}
let signal = generate_signal(symbol, current_row, previous_row);
// Execute buys
if signal.signal.is_buy() {
self.execute_buy(symbol, signal.current_price, *current_date, portfolio_value);
}
}
// Record equity
self.equity_history.push(EquityPoint {
date: *current_date,
portfolio_value: self.get_portfolio_value(&current_prices),
cash: self.cash,
positions_count: self.positions.len(),
});
// Progress update
if (day_num + 1) % 100 == 0 {
tracing::info!(
" Processed {}/{} days... Portfolio: ${:.2}",
day_num + 1,
trading_dates.len(),
self.equity_history.last().map(|e| e.portfolio_value).unwrap_or(0.0)
);
}
}
// Close all remaining positions at final prices
let final_date = trading_dates.last().copied().unwrap_or_else(Utc::now);
let position_symbols: Vec<String> = self.positions.keys().cloned().collect();
for symbol in position_symbols {
if let Some(rows) = data.get(&symbol) {
if let Some(last_row) = rows.last() {
self.execute_sell(&symbol, last_row.close, final_date);
}
}
}
// Calculate results
let result = self.calculate_results(years)?;
// Print summary
self.print_summary(&result);
Ok(result)
}
/// Calculate performance metrics from backtest.
fn calculate_results(&self, years: f64) -> Result<BacktestResult> {
if self.equity_history.is_empty() {
anyhow::bail!(
"No trading days after indicator warmup period. \
Try a longer backtest period (at least 4 months recommended)."
);
}
let final_value = self.cash;
let total_return = final_value - self.initial_capital;
let total_return_pct = total_return / self.initial_capital;
// CAGR
let cagr = if years > 0.0 {
(final_value / self.initial_capital).powf(1.0 / years) - 1.0
} else {
0.0
};
// Calculate daily returns
let mut daily_returns: Vec<f64> = Vec::new();
for i in 1..self.equity_history.len() {
let prev = self.equity_history[i - 1].portfolio_value;
let curr = self.equity_history[i].portfolio_value;
if prev > 0.0 {
daily_returns.push((curr - prev) / prev);
}
}
// Sharpe Ratio (assuming 252 trading days, risk-free rate ~5%)
let risk_free_daily = 0.05 / TRADING_DAYS_PER_YEAR as f64;
let excess_returns: Vec<f64> = daily_returns.iter().map(|r| r - risk_free_daily).collect();
let sharpe = if !excess_returns.is_empty() {
let mean = excess_returns.iter().sum::<f64>() / excess_returns.len() as f64;
let variance: f64 = excess_returns.iter().map(|r| (r - mean).powi(2)).sum::<f64>()
/ excess_returns.len() as f64;
let std = variance.sqrt();
if std > 0.0 {
(mean / std) * (TRADING_DAYS_PER_YEAR as f64).sqrt()
} else {
0.0
}
} else {
0.0
};
// Sortino Ratio (downside deviation)
let negative_returns: Vec<f64> = daily_returns.iter().filter(|&&r| r < 0.0).copied().collect();
let sortino = if !negative_returns.is_empty() && !daily_returns.is_empty() {
let mean = daily_returns.iter().sum::<f64>() / daily_returns.len() as f64;
let neg_variance: f64 =
negative_returns.iter().map(|r| r.powi(2)).sum::<f64>() / negative_returns.len() as f64;
let neg_std = neg_variance.sqrt();
if neg_std > 0.0 {
(mean / neg_std) * (TRADING_DAYS_PER_YEAR as f64).sqrt()
} else {
0.0
}
} else {
0.0
};
// Maximum Drawdown
let mut max_drawdown = 0.0;
let mut max_drawdown_pct = 0.0;
let mut peak = self.initial_capital;
for point in &self.equity_history {
if point.portfolio_value > peak {
peak = point.portfolio_value;
}
let drawdown = point.portfolio_value - peak;
let drawdown_pct = drawdown / peak;
if drawdown < max_drawdown {
max_drawdown = drawdown;
max_drawdown_pct = drawdown_pct;
}
}
// Trade statistics
let sell_trades: Vec<&Trade> = self.trades.iter().filter(|t| t.side == "SELL").collect();
let winning_trades: Vec<&Trade> = sell_trades.iter().filter(|t| t.pnl > 0.0).copied().collect();
let losing_trades: Vec<&Trade> = sell_trades.iter().filter(|t| t.pnl <= 0.0).copied().collect();
let total_trades = sell_trades.len();
let win_count = winning_trades.len();
let lose_count = losing_trades.len();
let win_rate = if total_trades > 0 {
win_count as f64 / total_trades as f64
} else {
0.0
};
let avg_win = if !winning_trades.is_empty() {
winning_trades.iter().map(|t| t.pnl).sum::<f64>() / winning_trades.len() as f64
} else {
0.0
};
let avg_loss = if !losing_trades.is_empty() {
losing_trades.iter().map(|t| t.pnl).sum::<f64>() / losing_trades.len() as f64
} else {
0.0
};
let total_wins: f64 = winning_trades.iter().map(|t| t.pnl).sum();
let total_losses: f64 = losing_trades.iter().map(|t| t.pnl.abs()).sum();
let profit_factor = if total_losses > 0.0 {
total_wins / total_losses
} else {
f64::INFINITY
};
Ok(BacktestResult {
initial_capital: self.initial_capital,
final_value,
total_return,
total_return_pct,
cagr,
sharpe_ratio: sharpe,
sortino_ratio: sortino,
max_drawdown,
max_drawdown_pct,
total_trades,
winning_trades: win_count,
losing_trades: lose_count,
win_rate,
avg_win,
avg_loss,
profit_factor,
trades: self.trades.clone(),
equity_curve: self.equity_history.clone(),
})
}
/// Print backtest summary.
fn print_summary(&self, result: &BacktestResult) {
println!("\n");
println!("{}", "=".repeat(70));
println!("{:^70}", "BACKTEST RESULTS");
println!("{}", "=".repeat(70));
println!("\n{:^70}", "PORTFOLIO PERFORMANCE");
println!("{}", "-".repeat(70));
println!(" Initial Capital: ${:>15.2}", result.initial_capital);
println!(" Final Value: ${:>15.2}", result.final_value);
println!(
" Total Return: ${:>15.2} ({:>+.2}%)",
result.total_return,
result.total_return_pct * 100.0
);
println!(" CAGR: {:>15.2}%", result.cagr * 100.0);
println!();
println!("{:^70}", "RISK METRICS");
println!("{}", "-".repeat(70));
println!(" Sharpe Ratio: {:>15.2}", result.sharpe_ratio);
println!(" Sortino Ratio: {:>15.2}", result.sortino_ratio);
println!(
" Max Drawdown: ${:>15.2} ({:.2}%)",
result.max_drawdown,
result.max_drawdown_pct * 100.0
);
println!();
println!("{:^70}", "TRADE STATISTICS");
println!("{}", "-".repeat(70));
println!(" Total Trades: {:>15}", result.total_trades);
println!(" Winning Trades: {:>15}", result.winning_trades);
println!(" Losing Trades: {:>15}", result.losing_trades);
println!(" Win Rate: {:>15.2}%", result.win_rate * 100.0);
println!(" Avg Win: ${:>15.2}", result.avg_win);
println!(" Avg Loss: ${:>15.2}", result.avg_loss);
println!(" Profit Factor: {:>15.2}", result.profit_factor);
println!("{}", "=".repeat(70));
// Show recent trades
if !result.trades.is_empty() {
println!("\n{:^70}", "RECENT TRADES (Last 10)");
println!("{}", "-".repeat(70));
println!(
" {:12} {:8} {:6} {:8} {:12} {:12}",
"Date", "Symbol", "Side", "Shares", "Price", "P&L"
);
println!("{}", "-".repeat(70));
for trade in result.trades.iter().rev().take(10).rev() {
let date_str = trade.timestamp.format("%Y-%m-%d").to_string();
let pnl_str = if trade.side == "SELL" {
format!("${:.2}", trade.pnl)
} else {
"-".to_string()
};
println!(
" {:12} {:8} {:6} {:8.0} ${:11.2} {:12}",
date_str, trade.symbol, trade.side, trade.shares, trade.price, pnl_str
);
}
println!("{}", "=".repeat(70));
}
}
}
/// Save backtest results to CSV files.
pub fn save_backtest_results(result: &BacktestResult) -> Result<()> {
// Save equity curve
if !result.equity_curve.is_empty() {
let mut wtr = csv::Writer::from_path("backtest_equity_curve.csv")
.context("Failed to create equity curve CSV")?;
wtr.write_record(["date", "portfolio_value", "cash", "positions_count"])?;
for point in &result.equity_curve {
wtr.write_record(&[
point.date.to_rfc3339(),
point.portfolio_value.to_string(),
point.cash.to_string(),
point.positions_count.to_string(),
])?;
}
wtr.flush()?;
tracing::info!("Equity curve saved to: backtest_equity_curve.csv");
}
// Save trades
if !result.trades.is_empty() {
let mut wtr =
csv::Writer::from_path("backtest_trades.csv").context("Failed to create trades CSV")?;
wtr.write_record(["timestamp", "symbol", "side", "shares", "price", "pnl", "pnl_pct"])?;
for trade in &result.trades {
wtr.write_record(&[
trade.timestamp.to_rfc3339(),
trade.symbol.clone(),
trade.side.clone(),
trade.shares.to_string(),
trade.price.to_string(),
trade.pnl.to_string(),
trade.pnl_pct.to_string(),
])?;
}
wtr.flush()?;
tracing::info!("Trades saved to: backtest_trades.csv");
}
Ok(())
}