478 lines
14 KiB
Rust
478 lines
14 KiB
Rust
//! Alpaca API client for market data and trading.
|
|
|
|
use anyhow::{Context, Result};
|
|
use chrono::{DateTime, Duration, Utc};
|
|
use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use tokio::sync::Mutex;
|
|
use tokio::time::{sleep, Duration as TokioDuration};
|
|
|
|
use crate::config::Timeframe;
|
|
use crate::types::Bar;
|
|
|
|
/// Base URL of the Alpaca market-data REST API (v2); used for historical bars.
const DATA_BASE_URL: &str = "https://data.alpaca.markets/v2";

/// Base URL of the Alpaca trading REST API (v2) used for account, positions,
/// clock and orders. NOTE: this points at the `paper-api` host.
const TRADING_BASE_URL: &str = "https://paper-api.alpaca.markets/v2";

/// Request budget enforced client-side: consecutive requests are spaced at least
/// `60 / RATE_LIMIT_REQUESTS_PER_MINUTE` seconds apart by the client's rate limiter.
const RATE_LIMIT_REQUESTS_PER_MINUTE: u32 = 200;
|
|
|
|
/// Alpaca API client.
|
|
pub struct AlpacaClient {
|
|
http_client: reqwest::Client,
|
|
api_key: String,
|
|
api_secret: String,
|
|
last_request_time: Arc<Mutex<std::time::Instant>>,
|
|
}
|
|
|
|
// API Response types
|
|
#[derive(Debug, Deserialize)]
|
|
struct BarsResponse {
|
|
bars: HashMap<String, Vec<AlpacaBar>>,
|
|
next_page_token: Option<String>,
|
|
}
|
|
|
|
// Single-symbol bars response (different format from multi-symbol)
|
|
#[derive(Debug, Deserialize)]
|
|
struct SingleBarsResponse {
|
|
bars: Vec<AlpacaBar>,
|
|
symbol: String,
|
|
next_page_token: Option<String>,
|
|
}
|
|
|
|
/// One OHLCV bar exactly as returned by the Alpaca bars endpoints.
///
/// Field names double as the JSON keys (no serde renames), hence the
/// one-letter names; values are copied into the crate's `Bar` type.
#[derive(Debug, Deserialize)]
struct AlpacaBar {
    /// Bar timestamp (becomes `Bar::timestamp`).
    t: DateTime<Utc>,
    /// Open price (becomes `Bar::open`).
    o: f64,
    /// High price (becomes `Bar::high`).
    h: f64,
    /// Low price (becomes `Bar::low`).
    l: f64,
    /// Close price (becomes `Bar::close`).
    c: f64,
    /// Volume (becomes `Bar::volume`).
    v: f64,
    /// Volume-weighted average price, if present (becomes `Bar::vwap`).
    vw: Option<f64>,
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct Account {
|
|
pub id: String,
|
|
pub status: String,
|
|
pub buying_power: String,
|
|
pub portfolio_value: String,
|
|
pub cash: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct Position {
|
|
pub symbol: String,
|
|
pub qty: String,
|
|
pub market_value: String,
|
|
pub avg_entry_price: String,
|
|
pub current_price: String,
|
|
pub unrealized_pl: String,
|
|
pub unrealized_plpc: String,
|
|
pub unrealized_intraday_pl: Option<String>,
|
|
pub change_today: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct Clock {
|
|
pub is_open: bool,
|
|
pub next_open: DateTime<Utc>,
|
|
pub next_close: DateTime<Utc>,
|
|
}
|
|
|
|
/// JSON body sent with `POST /orders`.
///
/// Field declaration order is the key order of the serialized JSON.
#[derive(Debug, Serialize)]
struct OrderRequest {
    /// Ticker symbol to trade.
    symbol: String,
    /// Order quantity, sent as a string.
    qty: String,
    /// Order side exactly as supplied by the caller (presumably "buy"/"sell";
    /// not validated here).
    side: String,
    /// Order type; serialized under the JSON key `type`, which is a Rust keyword
    /// and therefore cannot be the field name.
    #[serde(rename = "type")]
    order_type: String,
    /// Time-in-force policy for the order.
    time_in_force: String,
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct Order {
|
|
pub id: String,
|
|
pub symbol: String,
|
|
pub qty: String,
|
|
pub side: String,
|
|
pub status: String,
|
|
}
|
|
|
|
impl AlpacaClient {
|
|
/// Create a new Alpaca client.
|
|
pub fn new(api_key: String, api_secret: String) -> Result<Self> {
|
|
let mut headers = HeaderMap::new();
|
|
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
|
|
|
|
let http_client = reqwest::Client::builder()
|
|
.default_headers(headers)
|
|
.build()
|
|
.context("Failed to create HTTP client")?;
|
|
|
|
Ok(Self {
|
|
http_client,
|
|
api_key,
|
|
api_secret,
|
|
last_request_time: Arc::new(Mutex::new(std::time::Instant::now())),
|
|
})
|
|
}
|
|
|
|
/// Enforce rate limiting.
|
|
async fn enforce_rate_limit(&self) {
|
|
let min_interval =
|
|
TokioDuration::from_secs_f64(60.0 / RATE_LIMIT_REQUESTS_PER_MINUTE as f64);
|
|
|
|
let mut last_time = self.last_request_time.lock().await;
|
|
let elapsed = last_time.elapsed();
|
|
|
|
if elapsed < min_interval {
|
|
sleep(min_interval - elapsed).await;
|
|
}
|
|
|
|
*last_time = std::time::Instant::now();
|
|
}
|
|
|
|
/// Add authentication headers to a request.
|
|
fn auth_headers(&self) -> HeaderMap {
|
|
let mut headers = HeaderMap::new();
|
|
headers.insert(
|
|
"APCA-API-KEY-ID",
|
|
HeaderValue::from_str(&self.api_key).unwrap(),
|
|
);
|
|
headers.insert(
|
|
"APCA-API-SECRET-KEY",
|
|
HeaderValue::from_str(&self.api_secret).unwrap(),
|
|
);
|
|
headers
|
|
}
|
|
|
|
/// Fetch historical bar data for a symbol with pagination support.
|
|
pub async fn get_historical_bars(
|
|
&self,
|
|
symbol: &str,
|
|
timeframe: Timeframe,
|
|
start: DateTime<Utc>,
|
|
end: DateTime<Utc>,
|
|
) -> Result<Vec<Bar>> {
|
|
let tf_str = match timeframe {
|
|
Timeframe::Daily => "1Day",
|
|
Timeframe::Hourly => "1Hour",
|
|
};
|
|
|
|
let mut all_bars = Vec::new();
|
|
let mut page_token: Option<String> = None;
|
|
|
|
loop {
|
|
self.enforce_rate_limit().await;
|
|
|
|
let mut url = format!(
|
|
"{}/stocks/{}/bars?timeframe={}&start={}&end={}&feed=iex&limit=10000",
|
|
DATA_BASE_URL,
|
|
symbol,
|
|
tf_str,
|
|
start.format("%Y-%m-%dT%H:%M:%SZ"),
|
|
end.format("%Y-%m-%dT%H:%M:%SZ"),
|
|
);
|
|
|
|
if let Some(ref token) = page_token {
|
|
url.push_str(&format!("&page_token={}", token));
|
|
}
|
|
|
|
let response = self
|
|
.http_client
|
|
.get(&url)
|
|
.headers(self.auth_headers())
|
|
.send()
|
|
.await
|
|
.context("Failed to fetch bars")?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let text = response.text().await.unwrap_or_default();
|
|
anyhow::bail!("API error {}: {}", status, text);
|
|
}
|
|
|
|
// Single-symbol endpoint returns a different format
|
|
let data: SingleBarsResponse = response.json().await.context("Failed to parse bars response")?;
|
|
|
|
for b in &data.bars {
|
|
all_bars.push(Bar {
|
|
timestamp: b.t,
|
|
open: b.o,
|
|
high: b.h,
|
|
low: b.l,
|
|
close: b.c,
|
|
volume: b.v,
|
|
vwap: b.vw,
|
|
});
|
|
}
|
|
|
|
// Check for more pages
|
|
if let Some(token) = data.next_page_token {
|
|
if !token.is_empty() {
|
|
page_token = Some(token);
|
|
continue;
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
|
|
Ok(all_bars)
|
|
}
|
|
|
|
/// Fetch historical bars for multiple symbols.
|
|
pub async fn get_multi_historical_bars(
|
|
&self,
|
|
symbols: &[&str],
|
|
timeframe: Timeframe,
|
|
start: DateTime<Utc>,
|
|
end: DateTime<Utc>,
|
|
) -> Result<HashMap<String, Vec<Bar>>> {
|
|
self.enforce_rate_limit().await;
|
|
|
|
let tf_str = match timeframe {
|
|
Timeframe::Daily => "1Day",
|
|
Timeframe::Hourly => "1Hour",
|
|
};
|
|
|
|
let symbols_str = symbols.join(",");
|
|
let url = format!(
|
|
"{}/stocks/bars?symbols={}&timeframe={}&start={}&end={}&feed=iex&limit=10000",
|
|
DATA_BASE_URL,
|
|
symbols_str,
|
|
tf_str,
|
|
start.format("%Y-%m-%dT%H:%M:%SZ"),
|
|
end.format("%Y-%m-%dT%H:%M:%SZ"),
|
|
);
|
|
|
|
let response = self
|
|
.http_client
|
|
.get(&url)
|
|
.headers(self.auth_headers())
|
|
.send()
|
|
.await
|
|
.context("Failed to fetch bars")?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let text = response.text().await.unwrap_or_default();
|
|
anyhow::bail!("API error {}: {}", status, text);
|
|
}
|
|
|
|
let data: BarsResponse = response.json().await.context("Failed to parse bars response")?;
|
|
|
|
let mut result = HashMap::new();
|
|
for (symbol, bars) in data.bars {
|
|
let converted: Vec<Bar> = bars
|
|
.iter()
|
|
.map(|b| Bar {
|
|
timestamp: b.t,
|
|
open: b.o,
|
|
high: b.h,
|
|
low: b.l,
|
|
close: b.c,
|
|
volume: b.v,
|
|
vwap: b.vw,
|
|
})
|
|
.collect();
|
|
result.insert(symbol, converted);
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
/// Get account information.
|
|
pub async fn get_account(&self) -> Result<Account> {
|
|
self.enforce_rate_limit().await;
|
|
|
|
let url = format!("{}/account", TRADING_BASE_URL);
|
|
|
|
let response = self
|
|
.http_client
|
|
.get(&url)
|
|
.headers(self.auth_headers())
|
|
.send()
|
|
.await
|
|
.context("Failed to get account")?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let text = response.text().await.unwrap_or_default();
|
|
anyhow::bail!("API error {}: {}", status, text);
|
|
}
|
|
|
|
response.json().await.context("Failed to parse account")
|
|
}
|
|
|
|
/// Get all positions.
|
|
pub async fn get_positions(&self) -> Result<Vec<Position>> {
|
|
self.enforce_rate_limit().await;
|
|
|
|
let url = format!("{}/positions", TRADING_BASE_URL);
|
|
|
|
let response = self
|
|
.http_client
|
|
.get(&url)
|
|
.headers(self.auth_headers())
|
|
.send()
|
|
.await
|
|
.context("Failed to get positions")?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let text = response.text().await.unwrap_or_default();
|
|
anyhow::bail!("API error {}: {}", status, text);
|
|
}
|
|
|
|
response.json().await.context("Failed to parse positions")
|
|
}
|
|
|
|
/// Get position for a specific symbol.
|
|
pub async fn get_position(&self, symbol: &str) -> Result<Option<Position>> {
|
|
self.enforce_rate_limit().await;
|
|
|
|
let url = format!("{}/positions/{}", TRADING_BASE_URL, symbol);
|
|
|
|
let response = self
|
|
.http_client
|
|
.get(&url)
|
|
.headers(self.auth_headers())
|
|
.send()
|
|
.await
|
|
.context("Failed to get position")?;
|
|
|
|
if response.status().as_u16() == 404 {
|
|
return Ok(None);
|
|
}
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let text = response.text().await.unwrap_or_default();
|
|
anyhow::bail!("API error {}: {}", status, text);
|
|
}
|
|
|
|
let position: Position = response.json().await.context("Failed to parse position")?;
|
|
Ok(Some(position))
|
|
}
|
|
|
|
/// Get market clock.
|
|
pub async fn get_clock(&self) -> Result<Clock> {
|
|
self.enforce_rate_limit().await;
|
|
|
|
let url = format!("{}/clock", TRADING_BASE_URL);
|
|
|
|
let response = self
|
|
.http_client
|
|
.get(&url)
|
|
.headers(self.auth_headers())
|
|
.send()
|
|
.await
|
|
.context("Failed to get clock")?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let text = response.text().await.unwrap_or_default();
|
|
anyhow::bail!("API error {}: {}", status, text);
|
|
}
|
|
|
|
response.json().await.context("Failed to parse clock")
|
|
}
|
|
|
|
/// Submit a market order.
|
|
pub async fn submit_market_order(
|
|
&self,
|
|
symbol: &str,
|
|
qty: f64,
|
|
side: &str,
|
|
) -> Result<Order> {
|
|
self.enforce_rate_limit().await;
|
|
|
|
let url = format!("{}/orders", TRADING_BASE_URL);
|
|
|
|
let order_request = OrderRequest {
|
|
symbol: symbol.to_string(),
|
|
qty: qty.to_string(),
|
|
side: side.to_string(),
|
|
order_type: "market".to_string(),
|
|
time_in_force: "day".to_string(),
|
|
};
|
|
|
|
let response = self
|
|
.http_client
|
|
.post(&url)
|
|
.headers(self.auth_headers())
|
|
.json(&order_request)
|
|
.send()
|
|
.await
|
|
.context("Failed to submit order")?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let text = response.text().await.unwrap_or_default();
|
|
anyhow::bail!("API error {}: {}", status, text);
|
|
}
|
|
|
|
response.json().await.context("Failed to parse order")
|
|
}
|
|
|
|
/// Check if market is open.
|
|
pub async fn is_market_open(&self) -> Result<bool> {
|
|
let clock = self.get_clock().await?;
|
|
Ok(clock.is_open)
|
|
}
|
|
|
|
/// Get next market open time.
|
|
pub async fn get_next_market_open(&self) -> Result<DateTime<Utc>> {
|
|
let clock = self.get_clock().await?;
|
|
Ok(clock.next_open)
|
|
}
|
|
}
|
|
|
|
/// Helper to fetch bars for backtesting with proper date handling.
|
|
/// Fetches each symbol individually to avoid API limits on multi-symbol requests.
|
|
pub async fn fetch_backtest_data(
|
|
client: &AlpacaClient,
|
|
symbols: &[&str],
|
|
years: f64,
|
|
timeframe: Timeframe,
|
|
warmup_days: i64,
|
|
) -> Result<HashMap<String, Vec<Bar>>> {
|
|
let end = Utc::now();
|
|
let days = (years * 365.0) as i64 + warmup_days + 30;
|
|
let start = end - Duration::days(days);
|
|
|
|
tracing::info!(
|
|
"Fetching {:.2} years of data ({} to {})...",
|
|
years,
|
|
start.format("%Y-%m-%d"),
|
|
end.format("%Y-%m-%d")
|
|
);
|
|
|
|
let mut all_data = HashMap::new();
|
|
|
|
// Fetch each symbol individually like Python does
|
|
// The multi-symbol endpoint has a 10000 bar limit across ALL symbols
|
|
for symbol in symbols {
|
|
tracing::info!(" Fetching {}...", symbol);
|
|
|
|
match client
|
|
.get_historical_bars(symbol, timeframe, start, end)
|
|
.await
|
|
{
|
|
Ok(bars) => {
|
|
if !bars.is_empty() {
|
|
tracing::info!(" {}: {} bars loaded", symbol, bars.len());
|
|
all_data.insert(symbol.to_string(), bars);
|
|
} else {
|
|
tracing::warn!(" {}: No data", symbol);
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::error!(" Failed to fetch {}: {}", symbol, e);
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(all_data)
|
|
}
|