first comit
This commit is contained in:
476
src/alpaca.rs
Normal file
476
src/alpaca.rs
Normal file
@@ -0,0 +1,476 @@
|
||||
//! Alpaca API client for market data and trading.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time::{sleep, Duration as TokioDuration};
|
||||
|
||||
use crate::config::Timeframe;
|
||||
use crate::types::Bar;
|
||||
|
||||
const DATA_BASE_URL: &str = "https://data.alpaca.markets/v2";
|
||||
const TRADING_BASE_URL: &str = "https://paper-api.alpaca.markets/v2";
|
||||
const RATE_LIMIT_REQUESTS_PER_MINUTE: u32 = 200;
|
||||
|
||||
/// Alpaca API client.
|
||||
pub struct AlpacaClient {
|
||||
http_client: reqwest::Client,
|
||||
api_key: String,
|
||||
api_secret: String,
|
||||
last_request_time: Arc<Mutex<std::time::Instant>>,
|
||||
}
|
||||
|
||||
// API Response types
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct BarsResponse {
|
||||
bars: HashMap<String, Vec<AlpacaBar>>,
|
||||
next_page_token: Option<String>,
|
||||
}
|
||||
|
||||
// Single-symbol bars response (different format from multi-symbol)
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SingleBarsResponse {
|
||||
bars: Vec<AlpacaBar>,
|
||||
symbol: String,
|
||||
next_page_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AlpacaBar {
|
||||
t: DateTime<Utc>,
|
||||
o: f64,
|
||||
h: f64,
|
||||
l: f64,
|
||||
c: f64,
|
||||
v: f64,
|
||||
vw: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Account {
|
||||
pub id: String,
|
||||
pub status: String,
|
||||
pub buying_power: String,
|
||||
pub portfolio_value: String,
|
||||
pub cash: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Position {
|
||||
pub symbol: String,
|
||||
pub qty: String,
|
||||
pub market_value: String,
|
||||
pub avg_entry_price: String,
|
||||
pub current_price: String,
|
||||
pub unrealized_pl: String,
|
||||
pub unrealized_plpc: String,
|
||||
pub change_today: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Clock {
|
||||
pub is_open: bool,
|
||||
pub next_open: DateTime<Utc>,
|
||||
pub next_close: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct OrderRequest {
|
||||
symbol: String,
|
||||
qty: String,
|
||||
side: String,
|
||||
#[serde(rename = "type")]
|
||||
order_type: String,
|
||||
time_in_force: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Order {
|
||||
pub id: String,
|
||||
pub symbol: String,
|
||||
pub qty: String,
|
||||
pub side: String,
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
impl AlpacaClient {
|
||||
/// Create a new Alpaca client.
|
||||
pub fn new(api_key: String, api_secret: String) -> Result<Self> {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
|
||||
|
||||
let http_client = reqwest::Client::builder()
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.context("Failed to create HTTP client")?;
|
||||
|
||||
Ok(Self {
|
||||
http_client,
|
||||
api_key,
|
||||
api_secret,
|
||||
last_request_time: Arc::new(Mutex::new(std::time::Instant::now())),
|
||||
})
|
||||
}
|
||||
|
||||
/// Enforce rate limiting.
|
||||
async fn enforce_rate_limit(&self) {
|
||||
let min_interval =
|
||||
TokioDuration::from_secs_f64(60.0 / RATE_LIMIT_REQUESTS_PER_MINUTE as f64);
|
||||
|
||||
let mut last_time = self.last_request_time.lock().await;
|
||||
let elapsed = last_time.elapsed();
|
||||
|
||||
if elapsed < min_interval {
|
||||
sleep(min_interval - elapsed).await;
|
||||
}
|
||||
|
||||
*last_time = std::time::Instant::now();
|
||||
}
|
||||
|
||||
/// Add authentication headers to a request.
|
||||
fn auth_headers(&self) -> HeaderMap {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"APCA-API-KEY-ID",
|
||||
HeaderValue::from_str(&self.api_key).unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"APCA-API-SECRET-KEY",
|
||||
HeaderValue::from_str(&self.api_secret).unwrap(),
|
||||
);
|
||||
headers
|
||||
}
|
||||
|
||||
/// Fetch historical bar data for a symbol with pagination support.
|
||||
pub async fn get_historical_bars(
|
||||
&self,
|
||||
symbol: &str,
|
||||
timeframe: Timeframe,
|
||||
start: DateTime<Utc>,
|
||||
end: DateTime<Utc>,
|
||||
) -> Result<Vec<Bar>> {
|
||||
let tf_str = match timeframe {
|
||||
Timeframe::Daily => "1Day",
|
||||
Timeframe::Hourly => "1Hour",
|
||||
};
|
||||
|
||||
let mut all_bars = Vec::new();
|
||||
let mut page_token: Option<String> = None;
|
||||
|
||||
loop {
|
||||
self.enforce_rate_limit().await;
|
||||
|
||||
let mut url = format!(
|
||||
"{}/stocks/{}/bars?timeframe={}&start={}&end={}&feed=iex&limit=10000",
|
||||
DATA_BASE_URL,
|
||||
symbol,
|
||||
tf_str,
|
||||
start.format("%Y-%m-%dT%H:%M:%SZ"),
|
||||
end.format("%Y-%m-%dT%H:%M:%SZ"),
|
||||
);
|
||||
|
||||
if let Some(ref token) = page_token {
|
||||
url.push_str(&format!("&page_token={}", token));
|
||||
}
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.get(&url)
|
||||
.headers(self.auth_headers())
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to fetch bars")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("API error {}: {}", status, text);
|
||||
}
|
||||
|
||||
// Single-symbol endpoint returns a different format
|
||||
let data: SingleBarsResponse = response.json().await.context("Failed to parse bars response")?;
|
||||
|
||||
for b in &data.bars {
|
||||
all_bars.push(Bar {
|
||||
timestamp: b.t,
|
||||
open: b.o,
|
||||
high: b.h,
|
||||
low: b.l,
|
||||
close: b.c,
|
||||
volume: b.v,
|
||||
vwap: b.vw,
|
||||
});
|
||||
}
|
||||
|
||||
// Check for more pages
|
||||
if let Some(token) = data.next_page_token {
|
||||
if !token.is_empty() {
|
||||
page_token = Some(token);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
Ok(all_bars)
|
||||
}
|
||||
|
||||
/// Fetch historical bars for multiple symbols.
|
||||
pub async fn get_multi_historical_bars(
|
||||
&self,
|
||||
symbols: &[&str],
|
||||
timeframe: Timeframe,
|
||||
start: DateTime<Utc>,
|
||||
end: DateTime<Utc>,
|
||||
) -> Result<HashMap<String, Vec<Bar>>> {
|
||||
self.enforce_rate_limit().await;
|
||||
|
||||
let tf_str = match timeframe {
|
||||
Timeframe::Daily => "1Day",
|
||||
Timeframe::Hourly => "1Hour",
|
||||
};
|
||||
|
||||
let symbols_str = symbols.join(",");
|
||||
let url = format!(
|
||||
"{}/stocks/bars?symbols={}&timeframe={}&start={}&end={}&feed=iex&limit=10000",
|
||||
DATA_BASE_URL,
|
||||
symbols_str,
|
||||
tf_str,
|
||||
start.format("%Y-%m-%dT%H:%M:%SZ"),
|
||||
end.format("%Y-%m-%dT%H:%M:%SZ"),
|
||||
);
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.get(&url)
|
||||
.headers(self.auth_headers())
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to fetch bars")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("API error {}: {}", status, text);
|
||||
}
|
||||
|
||||
let data: BarsResponse = response.json().await.context("Failed to parse bars response")?;
|
||||
|
||||
let mut result = HashMap::new();
|
||||
for (symbol, bars) in data.bars {
|
||||
let converted: Vec<Bar> = bars
|
||||
.iter()
|
||||
.map(|b| Bar {
|
||||
timestamp: b.t,
|
||||
open: b.o,
|
||||
high: b.h,
|
||||
low: b.l,
|
||||
close: b.c,
|
||||
volume: b.v,
|
||||
vwap: b.vw,
|
||||
})
|
||||
.collect();
|
||||
result.insert(symbol, converted);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get account information.
|
||||
pub async fn get_account(&self) -> Result<Account> {
|
||||
self.enforce_rate_limit().await;
|
||||
|
||||
let url = format!("{}/account", TRADING_BASE_URL);
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.get(&url)
|
||||
.headers(self.auth_headers())
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to get account")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("API error {}: {}", status, text);
|
||||
}
|
||||
|
||||
response.json().await.context("Failed to parse account")
|
||||
}
|
||||
|
||||
/// Get all positions.
|
||||
pub async fn get_positions(&self) -> Result<Vec<Position>> {
|
||||
self.enforce_rate_limit().await;
|
||||
|
||||
let url = format!("{}/positions", TRADING_BASE_URL);
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.get(&url)
|
||||
.headers(self.auth_headers())
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to get positions")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("API error {}: {}", status, text);
|
||||
}
|
||||
|
||||
response.json().await.context("Failed to parse positions")
|
||||
}
|
||||
|
||||
/// Get position for a specific symbol.
|
||||
pub async fn get_position(&self, symbol: &str) -> Result<Option<Position>> {
|
||||
self.enforce_rate_limit().await;
|
||||
|
||||
let url = format!("{}/positions/{}", TRADING_BASE_URL, symbol);
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.get(&url)
|
||||
.headers(self.auth_headers())
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to get position")?;
|
||||
|
||||
if response.status().as_u16() == 404 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("API error {}: {}", status, text);
|
||||
}
|
||||
|
||||
let position: Position = response.json().await.context("Failed to parse position")?;
|
||||
Ok(Some(position))
|
||||
}
|
||||
|
||||
/// Get market clock.
|
||||
pub async fn get_clock(&self) -> Result<Clock> {
|
||||
self.enforce_rate_limit().await;
|
||||
|
||||
let url = format!("{}/clock", TRADING_BASE_URL);
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.get(&url)
|
||||
.headers(self.auth_headers())
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to get clock")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("API error {}: {}", status, text);
|
||||
}
|
||||
|
||||
response.json().await.context("Failed to parse clock")
|
||||
}
|
||||
|
||||
/// Submit a market order.
|
||||
pub async fn submit_market_order(
|
||||
&self,
|
||||
symbol: &str,
|
||||
qty: f64,
|
||||
side: &str,
|
||||
) -> Result<Order> {
|
||||
self.enforce_rate_limit().await;
|
||||
|
||||
let url = format!("{}/orders", TRADING_BASE_URL);
|
||||
|
||||
let order_request = OrderRequest {
|
||||
symbol: symbol.to_string(),
|
||||
qty: qty.to_string(),
|
||||
side: side.to_string(),
|
||||
order_type: "market".to_string(),
|
||||
time_in_force: "day".to_string(),
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.post(&url)
|
||||
.headers(self.auth_headers())
|
||||
.json(&order_request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to submit order")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("API error {}: {}", status, text);
|
||||
}
|
||||
|
||||
response.json().await.context("Failed to parse order")
|
||||
}
|
||||
|
||||
/// Check if market is open.
|
||||
pub async fn is_market_open(&self) -> Result<bool> {
|
||||
let clock = self.get_clock().await?;
|
||||
Ok(clock.is_open)
|
||||
}
|
||||
|
||||
/// Get next market open time.
|
||||
pub async fn get_next_market_open(&self) -> Result<DateTime<Utc>> {
|
||||
let clock = self.get_clock().await?;
|
||||
Ok(clock.next_open)
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to fetch bars for backtesting with proper date handling.
|
||||
/// Fetches each symbol individually to avoid API limits on multi-symbol requests.
|
||||
pub async fn fetch_backtest_data(
|
||||
client: &AlpacaClient,
|
||||
symbols: &[&str],
|
||||
years: f64,
|
||||
timeframe: Timeframe,
|
||||
warmup_days: i64,
|
||||
) -> Result<HashMap<String, Vec<Bar>>> {
|
||||
let end = Utc::now();
|
||||
let days = (years * 365.0) as i64 + warmup_days + 30;
|
||||
let start = end - Duration::days(days);
|
||||
|
||||
tracing::info!(
|
||||
"Fetching {:.2} years of data ({} to {})...",
|
||||
years,
|
||||
start.format("%Y-%m-%d"),
|
||||
end.format("%Y-%m-%d")
|
||||
);
|
||||
|
||||
let mut all_data = HashMap::new();
|
||||
|
||||
// Fetch each symbol individually like Python does
|
||||
// The multi-symbol endpoint has a 10000 bar limit across ALL symbols
|
||||
for symbol in symbols {
|
||||
tracing::info!(" Fetching {}...", symbol);
|
||||
|
||||
match client
|
||||
.get_historical_bars(symbol, timeframe, start, end)
|
||||
.await
|
||||
{
|
||||
Ok(bars) => {
|
||||
if !bars.is_empty() {
|
||||
tracing::info!(" {}: {} bars loaded", symbol, bars.len());
|
||||
all_data.insert(symbol.to_string(), bars);
|
||||
} else {
|
||||
tracing::warn!(" {}: No data", symbol);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(" Failed to fetch {}: {}", symbol, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(all_data)
|
||||
}
|
||||
Reference in New Issue
Block a user