595 lines
22 KiB
Rust
595 lines
22 KiB
Rust
//! Live trading bot using Alpaca API.
|
|
|
|
use anyhow::Result;
|
|
use chrono::{Duration, Utc};
|
|
use std::collections::HashMap;
|
|
use tokio::time::{sleep, Duration as TokioDuration};
|
|
|
|
use crate::alpaca::AlpacaClient;
|
|
use crate::config::{
|
|
get_all_symbols, IndicatorParams, Timeframe, BOT_CHECK_INTERVAL_SECONDS, HOURS_PER_DAY,
|
|
MAX_POSITION_SIZE, MIN_CASH_RESERVE, STOP_LOSS_PCT, TAKE_PROFIT_PCT,
|
|
TOP_MOMENTUM_COUNT, TRAILING_STOP_ACTIVATION, TRAILING_STOP_DISTANCE,
|
|
};
|
|
use crate::indicators::{calculate_all_indicators, generate_signal};
|
|
use crate::paths::{LIVE_EQUITY_FILE, LIVE_HIGH_WATER_MARKS_FILE, LIVE_POSITIONS_FILE};
|
|
use crate::types::{EquitySnapshot, PositionInfo, Signal, TradeSignal};
|
|
|
|
/// Live trading bot for paper trading.
|
|
pub struct TradingBot {
|
|
client: AlpacaClient,
|
|
params: IndicatorParams,
|
|
timeframe: Timeframe,
|
|
entry_prices: HashMap<String, f64>,
|
|
high_water_marks: HashMap<String, f64>,
|
|
equity_history: Vec<EquitySnapshot>,
|
|
}
|
|
|
|
impl TradingBot {
|
|
/// Create a new trading bot.
|
|
pub async fn new(
|
|
api_key: String,
|
|
api_secret: String,
|
|
timeframe: Timeframe,
|
|
) -> Result<Self> {
|
|
let client = AlpacaClient::new(api_key, api_secret)?;
|
|
|
|
let mut bot = Self {
|
|
client,
|
|
params: timeframe.params(),
|
|
timeframe,
|
|
entry_prices: HashMap::new(),
|
|
high_water_marks: HashMap::new(),
|
|
equity_history: Vec::new(),
|
|
};
|
|
|
|
// Load persisted state
|
|
bot.load_entry_prices();
|
|
bot.load_high_water_marks();
|
|
bot.load_equity_history();
|
|
|
|
// Log account info
|
|
bot.log_account_info().await;
|
|
|
|
tracing::info!("Trading bot initialized successfully (Paper Trading Mode)");
|
|
|
|
Ok(bot)
|
|
}
|
|
|
|
/// Load entry prices from file.
|
|
fn load_entry_prices(&mut self) {
|
|
if LIVE_POSITIONS_FILE.exists() {
|
|
match std::fs::read_to_string(&*LIVE_POSITIONS_FILE) {
|
|
Ok(content) => {
|
|
if !content.is_empty() {
|
|
match serde_json::from_str::<HashMap<String, f64>>(&content) {
|
|
Ok(prices) => {
|
|
tracing::info!("Loaded entry prices for {} positions.", prices.len());
|
|
self.entry_prices = prices;
|
|
}
|
|
Err(e) => tracing::error!("Error parsing positions file: {}", e),
|
|
}
|
|
}
|
|
}
|
|
Err(e) => tracing::error!("Error loading positions file: {}", e),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Save entry prices to file.
|
|
fn save_entry_prices(&self) {
|
|
match serde_json::to_string_pretty(&self.entry_prices) {
|
|
Ok(json) => {
|
|
if let Err(e) = std::fs::write(&*LIVE_POSITIONS_FILE, json) {
|
|
tracing::error!("Error saving positions file: {}", e);
|
|
}
|
|
}
|
|
Err(e) => tracing::error!("Error serializing positions: {}", e),
|
|
}
|
|
}
|
|
|
|
/// Load high water marks from file.
|
|
fn load_high_water_marks(&mut self) {
|
|
if LIVE_HIGH_WATER_MARKS_FILE.exists() {
|
|
match std::fs::read_to_string(&*LIVE_HIGH_WATER_MARKS_FILE) {
|
|
Ok(content) => {
|
|
if !content.is_empty() {
|
|
match serde_json::from_str::<HashMap<String, f64>>(&content) {
|
|
Ok(marks) => {
|
|
tracing::info!("Loaded high water marks for {} positions.", marks.len());
|
|
self.high_water_marks = marks;
|
|
}
|
|
Err(e) => tracing::error!("Error parsing high water marks file: {}", e),
|
|
}
|
|
}
|
|
}
|
|
Err(e) => tracing::error!("Error loading high water marks file: {}", e),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Save high water marks to file.
|
|
fn save_high_water_marks(&self) {
|
|
match serde_json::to_string_pretty(&self.high_water_marks) {
|
|
Ok(json) => {
|
|
if let Err(e) = std::fs::write(&*LIVE_HIGH_WATER_MARKS_FILE, json) {
|
|
tracing::error!("Error saving high water marks file: {}", e);
|
|
}
|
|
}
|
|
Err(e) => tracing::error!("Error serializing high water marks: {}", e),
|
|
}
|
|
}
|
|
|
|
/// Load equity history from file.
|
|
fn load_equity_history(&mut self) {
|
|
if LIVE_EQUITY_FILE.exists() {
|
|
match std::fs::read_to_string(&*LIVE_EQUITY_FILE) {
|
|
Ok(content) => {
|
|
if !content.is_empty() {
|
|
match serde_json::from_str::<Vec<EquitySnapshot>>(&content) {
|
|
Ok(history) => {
|
|
tracing::info!("Loaded {} equity data points.", history.len());
|
|
self.equity_history = history;
|
|
}
|
|
Err(e) => tracing::error!("Error parsing equity history: {}", e),
|
|
}
|
|
}
|
|
}
|
|
Err(e) => tracing::error!("Error loading equity history: {}", e),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Save equity snapshot.
|
|
async fn save_equity_snapshot(&mut self) -> Result<()> {
|
|
let account = self.client.get_account().await?;
|
|
let positions = self.client.get_positions().await?;
|
|
|
|
let mut positions_map = HashMap::new();
|
|
for pos in &positions {
|
|
positions_map.insert(
|
|
pos.symbol.clone(),
|
|
PositionInfo {
|
|
qty: pos.qty.parse().unwrap_or(0.0),
|
|
market_value: pos.market_value.parse().unwrap_or(0.0),
|
|
avg_entry_price: pos.avg_entry_price.parse().unwrap_or(0.0),
|
|
current_price: pos.current_price.parse().unwrap_or(0.0),
|
|
unrealized_pnl: pos.unrealized_pl.parse().unwrap_or(0.0),
|
|
pnl_pct: pos.unrealized_plpc.parse::<f64>().unwrap_or(0.0) * 100.0,
|
|
change_today: pos.change_today.as_ref().and_then(|s| s.parse::<f64>().ok()).unwrap_or(0.0) * 100.0,
|
|
},
|
|
);
|
|
}
|
|
|
|
let snapshot = EquitySnapshot {
|
|
timestamp: Utc::now().to_rfc3339(),
|
|
portfolio_value: account.portfolio_value.parse().unwrap_or(0.0),
|
|
cash: account.cash.parse().unwrap_or(0.0),
|
|
buying_power: account.buying_power.parse().unwrap_or(0.0),
|
|
positions_count: positions.len(),
|
|
positions: positions_map,
|
|
};
|
|
|
|
self.equity_history.push(snapshot.clone());
|
|
|
|
// Keep last 7 trading days of equity data (4 snapshots per minute at 15s intervals).
|
|
const SNAPSHOTS_PER_MINUTE: usize = 4;
|
|
const MINUTES_PER_HOUR: usize = 60;
|
|
const DAYS_TO_KEEP: usize = 7;
|
|
const MAX_SNAPSHOTS: usize = DAYS_TO_KEEP * HOURS_PER_DAY * MINUTES_PER_HOUR * SNAPSHOTS_PER_MINUTE;
|
|
|
|
if self.equity_history.len() > MAX_SNAPSHOTS {
|
|
let start = self.equity_history.len() - MAX_SNAPSHOTS;
|
|
self.equity_history = self.equity_history[start..].to_vec();
|
|
}
|
|
|
|
// Save to file
|
|
match serde_json::to_string_pretty(&self.equity_history) {
|
|
Ok(json) => {
|
|
if let Err(e) = std::fs::write(&*LIVE_EQUITY_FILE, json) {
|
|
tracing::error!("Error saving equity history: {}", e);
|
|
}
|
|
}
|
|
Err(e) => tracing::error!("Error serializing equity history: {}", e),
|
|
}
|
|
|
|
tracing::info!("Saved equity snapshot: ${:.2}", snapshot.portfolio_value);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Log current account information.
|
|
async fn log_account_info(&self) {
|
|
match self.client.get_account().await {
|
|
Ok(account) => {
|
|
let portfolio_value: f64 = account.portfolio_value.parse().unwrap_or(0.0);
|
|
let buying_power: f64 = account.buying_power.parse().unwrap_or(0.0);
|
|
let cash: f64 = account.cash.parse().unwrap_or(0.0);
|
|
|
|
tracing::info!("Account Status: {}", account.status);
|
|
tracing::info!("Buying Power: ${:.2}", buying_power);
|
|
tracing::info!("Portfolio Value: ${:.2}", portfolio_value);
|
|
tracing::info!("Cash: ${:.2}", cash);
|
|
}
|
|
Err(e) => tracing::error!("Failed to get account info: {}", e),
|
|
}
|
|
}
|
|
|
|
/// Get position quantity for a symbol.
|
|
async fn get_position(&self, symbol: &str) -> Option<f64> {
|
|
match self.client.get_position(symbol).await {
|
|
Ok(Some(pos)) => pos.qty.parse().ok(),
|
|
Ok(None) => None,
|
|
Err(e) => {
|
|
tracing::error!("Failed to get position for {}: {}", symbol, e);
|
|
None
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Calculate position size based on risk management.
|
|
async fn calculate_position_size(&self, price: f64) -> u64 {
|
|
let account = match self.client.get_account().await {
|
|
Ok(a) => a,
|
|
Err(e) => {
|
|
tracing::error!("Failed to get account: {}", e);
|
|
return 0;
|
|
}
|
|
};
|
|
|
|
let portfolio_value: f64 = account.portfolio_value.parse().unwrap_or(0.0);
|
|
let cash: f64 = account.cash.parse().unwrap_or(0.0);
|
|
|
|
let max_allocation = portfolio_value * MAX_POSITION_SIZE;
|
|
let available_funds = cash - (portfolio_value * MIN_CASH_RESERVE);
|
|
|
|
if available_funds <= 0.0 {
|
|
return 0;
|
|
}
|
|
|
|
let position_value = max_allocation.min(available_funds);
|
|
(position_value / price).floor() as u64
|
|
}
|
|
|
|
/// Check if stop-loss, take-profit, or trailing stop should trigger.
|
|
fn check_stop_loss_take_profit(&mut self, symbol: &str, current_price: f64) -> Option<Signal> {
|
|
let entry_price = match self.entry_prices.get(symbol) {
|
|
Some(&p) => p,
|
|
None => return None,
|
|
};
|
|
|
|
let pnl_pct = (current_price - entry_price) / entry_price;
|
|
|
|
// Update high water mark
|
|
if let Some(hwm) = self.high_water_marks.get_mut(symbol) {
|
|
if current_price > *hwm {
|
|
*hwm = current_price;
|
|
self.save_high_water_marks();
|
|
}
|
|
}
|
|
|
|
// Fixed stop loss
|
|
if pnl_pct <= -STOP_LOSS_PCT {
|
|
tracing::warn!("{}: Stop-loss triggered at {:.2}% loss", symbol, pnl_pct * 100.0);
|
|
return Some(Signal::StrongSell);
|
|
}
|
|
|
|
// Take profit
|
|
if pnl_pct >= TAKE_PROFIT_PCT {
|
|
tracing::info!("{}: Take-profit triggered at {:.2}% gain", symbol, pnl_pct * 100.0);
|
|
return Some(Signal::Sell);
|
|
}
|
|
|
|
// Trailing stop (only after activation threshold)
|
|
if pnl_pct >= TRAILING_STOP_ACTIVATION {
|
|
if let Some(&high_water) = self.high_water_marks.get(symbol) {
|
|
let trailing_stop_price = high_water * (1.0 - TRAILING_STOP_DISTANCE);
|
|
if current_price <= trailing_stop_price {
|
|
tracing::info!(
|
|
"{}: Trailing stop triggered at ${:.2} (peak: ${:.2}, stop: ${:.2})",
|
|
symbol, current_price, high_water, trailing_stop_price
|
|
);
|
|
return Some(Signal::Sell);
|
|
}
|
|
}
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
/// Execute a buy order.
|
|
async fn execute_buy(&mut self, symbol: &str, signal: &TradeSignal) -> bool {
|
|
// Check if already holding
|
|
if let Some(qty) = self.get_position(symbol).await {
|
|
if qty > 0.0 {
|
|
tracing::info!("{}: Already holding {} shares, skipping buy", symbol, qty);
|
|
return false;
|
|
}
|
|
}
|
|
|
|
let shares = self.calculate_position_size(signal.current_price).await;
|
|
if shares == 0 {
|
|
tracing::info!("{}: Insufficient funds for purchase", symbol);
|
|
return false;
|
|
}
|
|
|
|
match self
|
|
.client
|
|
.submit_market_order(symbol, shares as f64, "buy")
|
|
.await
|
|
{
|
|
Ok(_order) => {
|
|
self.entry_prices.insert(symbol.to_string(), signal.current_price);
|
|
self.high_water_marks.insert(symbol.to_string(), signal.current_price);
|
|
self.save_entry_prices();
|
|
self.save_high_water_marks();
|
|
|
|
tracing::info!(
|
|
"BUY ORDER EXECUTED: {} - {} shares @ ~${:.2} \
|
|
(RSI: {:.1}, MACD: {:.3}, Confidence: {:.2})",
|
|
symbol,
|
|
shares,
|
|
signal.current_price,
|
|
signal.rsi,
|
|
signal.macd_histogram,
|
|
signal.confidence
|
|
);
|
|
|
|
true
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("Failed to execute buy for {}: {}", symbol, e);
|
|
false
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Execute a sell order.
|
|
async fn execute_sell(&mut self, symbol: &str, signal: &TradeSignal) -> bool {
|
|
let current_position = match self.get_position(symbol).await {
|
|
Some(qty) if qty > 0.0 => qty,
|
|
_ => {
|
|
tracing::info!("{}: No position to sell", symbol);
|
|
return false;
|
|
}
|
|
};
|
|
|
|
match self
|
|
.client
|
|
.submit_market_order(symbol, current_position, "sell")
|
|
.await
|
|
{
|
|
Ok(_order) => {
|
|
if let Some(entry) = self.entry_prices.remove(symbol) {
|
|
let pnl_pct = (signal.current_price - entry) / entry;
|
|
tracing::info!("{}: Realized P&L: {:.2}%", symbol, pnl_pct * 100.0);
|
|
self.save_entry_prices();
|
|
}
|
|
self.high_water_marks.remove(symbol);
|
|
self.save_high_water_marks();
|
|
|
|
tracing::info!(
|
|
"SELL ORDER EXECUTED: {} - {} shares @ ~${:.2} \
|
|
(RSI: {:.1}, MACD: {:.3})",
|
|
symbol,
|
|
current_position,
|
|
signal.current_price,
|
|
signal.rsi,
|
|
signal.macd_histogram
|
|
);
|
|
|
|
true
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("Failed to execute sell for {}: {}", symbol, e);
|
|
false
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Analyze a symbol and generate trading signal (without stop-loss check).
|
|
async fn analyze_symbol(&self, symbol: &str) -> Option<TradeSignal> {
|
|
let min_bars = self.params.min_bars();
|
|
|
|
// Calculate days needed for data
|
|
let days = if self.timeframe == Timeframe::Hourly {
|
|
(min_bars as f64 / HOURS_PER_DAY as f64 * 1.5) as i64 + 10
|
|
} else {
|
|
(min_bars as f64 * 1.5) as i64 + 30
|
|
};
|
|
|
|
let end = Utc::now();
|
|
let start = end - Duration::days(days);
|
|
|
|
let bars = match self.client.get_historical_bars(symbol, self.timeframe, start, end).await {
|
|
Ok(b) => b,
|
|
Err(e) => {
|
|
tracing::warn!("{}: Failed to get historical data: {}", symbol, e);
|
|
return None;
|
|
}
|
|
};
|
|
|
|
if bars.len() < min_bars {
|
|
tracing::warn!(
|
|
"{}: Only {} bars, need {} for indicators",
|
|
symbol,
|
|
bars.len(),
|
|
min_bars
|
|
);
|
|
return None;
|
|
}
|
|
|
|
let indicators = calculate_all_indicators(&bars, &self.params);
|
|
|
|
if indicators.len() < 2 {
|
|
return None;
|
|
}
|
|
|
|
let current = &indicators[indicators.len() - 1];
|
|
let previous = &indicators[indicators.len() - 2];
|
|
|
|
if current.rsi.is_nan() || current.macd.is_nan() {
|
|
return None;
|
|
}
|
|
|
|
Some(generate_signal(symbol, current, previous))
|
|
}
|
|
|
|
/// Execute one complete trading cycle.
|
|
async fn run_trading_cycle(&mut self) {
|
|
tracing::info!("{}", "=".repeat(60));
|
|
tracing::info!("Starting trading cycle...");
|
|
self.log_account_info().await;
|
|
|
|
let symbols = get_all_symbols();
|
|
|
|
// Analyze all symbols first
|
|
let mut signals: Vec<TradeSignal> = Vec::new();
|
|
for symbol in &symbols {
|
|
tracing::info!("\nAnalyzing {}...", symbol);
|
|
|
|
let signal = match self.analyze_symbol(symbol).await {
|
|
Some(s) => s,
|
|
None => {
|
|
tracing::warn!("{}: Analysis failed, skipping", symbol);
|
|
continue;
|
|
}
|
|
};
|
|
|
|
tracing::info!(
|
|
"{}: Signal={}, RSI={:.1}, MACD Hist={:.3}, Momentum={:.2}%, \
|
|
Price=${:.2}, Confidence={:.2}",
|
|
signal.symbol,
|
|
signal.signal.as_str(),
|
|
signal.rsi,
|
|
signal.macd_histogram,
|
|
signal.momentum,
|
|
signal.current_price,
|
|
signal.confidence
|
|
);
|
|
|
|
signals.push(signal);
|
|
|
|
// Small delay between symbols for rate limiting
|
|
sleep(TokioDuration::from_millis(500)).await;
|
|
}
|
|
|
|
// Phase 1: Process all sells first (free up cash before buying)
|
|
for signal in &signals {
|
|
let mut effective_signal = signal.clone();
|
|
|
|
// Check stop-loss/take-profit/trailing stop
|
|
if let Some(sl_tp) = self.check_stop_loss_take_profit(&signal.symbol, signal.current_price) {
|
|
effective_signal.signal = sl_tp;
|
|
}
|
|
|
|
if effective_signal.signal.is_sell() {
|
|
self.execute_sell(&signal.symbol, &effective_signal).await;
|
|
}
|
|
}
|
|
|
|
// Phase 2: Momentum ranking - only buy top N momentum stocks
|
|
let mut ranked_signals: Vec<&TradeSignal> = signals
|
|
.iter()
|
|
.filter(|s| !s.momentum.is_nan())
|
|
.collect();
|
|
ranked_signals.sort_by(|a, b| {
|
|
b.momentum.partial_cmp(&a.momentum).unwrap_or(std::cmp::Ordering::Equal)
|
|
});
|
|
|
|
let top_momentum_symbols: std::collections::HashSet<String> = ranked_signals
|
|
.iter()
|
|
.take(TOP_MOMENTUM_COUNT)
|
|
.map(|s| s.symbol.clone())
|
|
.collect();
|
|
|
|
tracing::info!(
|
|
"Top {} momentum stocks: {:?}",
|
|
TOP_MOMENTUM_COUNT,
|
|
top_momentum_symbols
|
|
);
|
|
|
|
// Phase 3: Process buys in momentum-ranked order (highest momentum first)
|
|
for signal in &ranked_signals {
|
|
if !top_momentum_symbols.contains(&signal.symbol) {
|
|
continue;
|
|
}
|
|
|
|
if signal.signal.is_buy() {
|
|
self.execute_buy(&signal.symbol, signal).await;
|
|
}
|
|
}
|
|
|
|
// Save equity snapshot for dashboard
|
|
if let Err(e) = self.save_equity_snapshot().await {
|
|
tracing::error!("Failed to save equity snapshot: {}", e);
|
|
}
|
|
|
|
tracing::info!("Trading cycle complete");
|
|
tracing::info!("{}", "=".repeat(60));
|
|
}
|
|
|
|
/// Main bot loop - runs continuously during market hours.
|
|
pub async fn run(&mut self) -> Result<()> {
|
|
let symbols = get_all_symbols();
|
|
|
|
tracing::info!("{}", "=".repeat(60));
|
|
tracing::info!("TECH GIANTS TRADING BOT STARTED");
|
|
tracing::info!("Timeframe: {:?} bars", self.timeframe);
|
|
if self.timeframe == Timeframe::Hourly {
|
|
tracing::info!(
|
|
"Parameters scaled {}x (RSI: {}, EMA_TREND: {})",
|
|
HOURS_PER_DAY,
|
|
self.params.rsi_period,
|
|
self.params.ema_trend
|
|
);
|
|
}
|
|
tracing::info!("Symbols: {}", symbols.join(", "));
|
|
tracing::info!(
|
|
"Strategy: RSI({}) + MACD({},{},{}) + Momentum",
|
|
self.params.rsi_period,
|
|
self.params.macd_fast,
|
|
self.params.macd_slow,
|
|
self.params.macd_signal
|
|
);
|
|
tracing::info!("Bot Check Interval: {} seconds", BOT_CHECK_INTERVAL_SECONDS);
|
|
tracing::info!("{}", "=".repeat(60));
|
|
|
|
loop {
|
|
match self.client.is_market_open().await {
|
|
Ok(true) => {
|
|
self.run_trading_cycle().await;
|
|
|
|
tracing::info!(
|
|
"Next signal check in {} seconds...",
|
|
BOT_CHECK_INTERVAL_SECONDS
|
|
);
|
|
sleep(TokioDuration::from_secs(BOT_CHECK_INTERVAL_SECONDS)).await;
|
|
}
|
|
Ok(false) => {
|
|
match self.client.get_next_market_open().await {
|
|
Ok(next_open) => {
|
|
let wait_seconds = (next_open - Utc::now()).num_seconds().max(0);
|
|
tracing::info!("Market closed. Next open: {}", next_open);
|
|
tracing::info!("Waiting {:.1} hours...", wait_seconds as f64 / 3600.0);
|
|
|
|
let sleep_time = (wait_seconds as u64).min(300).max(60);
|
|
sleep(TokioDuration::from_secs(sleep_time)).await;
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("Failed to get next market open: {}", e);
|
|
tracing::info!("Market closed. Checking again in 5 minutes...");
|
|
sleep(TokioDuration::from_secs(300)).await;
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("Failed to check market status: {}", e);
|
|
tracing::info!("Retrying in 60 seconds...");
|
|
sleep(TokioDuration::from_secs(60)).await;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|