//! Expression AST, constant evaluation, and relocation classification for assembler
//! directives and operands.
2
3 use std::collections::BTreeMap;
4 use std::fmt;
5
/// Expression tree for the values of assembler directives and operands.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Expr {
    /// Integer literal.
    Int(i64),
    /// Reference to a symbol by name.
    Symbol(String),
    /// Symbol reference that carries a modifier (see [`SymbolModifier`]).
    ModifiedSymbol {
        /// Name of the referenced symbol.
        symbol: String,
        /// Modifier applied to the reference.
        modifier: SymbolModifier,
    },
    /// The current location; evaluation errors report it as `.`.
    CurrentLocation,
    /// Arithmetic negation of the operand.
    UnaryMinus(Box<Expr>),
    /// Sum of the two operands.
    Add(Box<Expr>, Box<Expr>),
    /// Left operand minus right operand.
    Sub(Box<Expr>, Box<Expr>),
}
19
/// Modifier that can be attached to a symbol reference in an [`Expr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SymbolModifier {
    /// The `@GOT` modifier; `classify` turns such references into pointer-to-GOT results.
    Got,
}
24
/// Failure modes of constant evaluation (`eval_pure` / `eval_with_symbols`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EvalError {
    /// The named operand has no constant value. Besides symbols missing from the table,
    /// the evaluator also reports `name@GOT` references and `.` through this variant.
    UndefinedSymbol(String),
    /// A checked `i64` negation, addition, or subtraction overflowed.
    Overflow,
}
30
/// What `classify` knows about a symbol when it looks the name up in its symbol table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SymbolValue {
    /// Constant value; `classify` folds it into the expression's constant part.
    Absolute(i64),
    /// Symbol defined in a section. `classify` groups symbols by `section` and subtracts
    /// the `value`s of symbols in the same section from one another.
    // NOTE(review): `value` is presumably a section-relative offset — confirm with the producer.
    Defined { section: usize, value: i64 },
    /// No definition available; `classify` also uses this for names absent from the table
    /// and keeps the symbol as a symbolic term.
    Undefined,
}
37
/// Result shapes produced by `classify`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClassifiedExpr {
    /// The expression reduced to a constant with no symbolic terms left.
    Absolute(i64),
    /// Exactly one symbol remains with coefficient `+1`.
    Relocatable {
        /// The remaining symbol.
        symbol: String,
        /// Constant part of the expression.
        addend: i64,
    },
    /// Exactly two symbols remain, one with coefficient `+1` and one with `-1`.
    Difference {
        /// Symbol with coefficient `+1`.
        minuend: String,
        /// Symbol with coefficient `-1`.
        subtrahend: String,
        /// Constant part of the expression.
        addend: i64,
    },
    /// A single `@GOT` reference with coefficient `+1`.
    PointerToGot {
        /// Symbol named by the `@GOT` reference.
        symbol: String,
        /// Constant part of the expression.
        addend: i64,
        /// `true` when the expression subtracts the current location exactly once.
        pcrel: bool,
    },
}
56
/// Failure modes of `classify`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClassifyError {
    /// Checked `i64` arithmetic on the constant part overflowed.
    Overflow,
    /// The expression does not match a supported shape; the payload is the message that
    /// `Display` prints verbatim.
    Illegal(String),
}
62
63 impl fmt::Display for EvalError {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 match self {
66 Self::UndefinedSymbol(symbol) => write!(f, "undefined absolute symbol '{}'", symbol),
67 Self::Overflow => write!(f, "expression overflows i64"),
68 }
69 }
70 }
71
// Empty impl: the default `Error` methods are used; the message comes from the `Display` impl.
impl std::error::Error for EvalError {}
73
74 impl fmt::Display for ClassifyError {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 match self {
77 Self::Overflow => write!(f, "expression overflows i64"),
78 Self::Illegal(msg) => write!(f, "{}", msg),
79 }
80 }
81 }
82
// Empty impl: the default `Error` methods are used; the message comes from the `Display` impl.
impl std::error::Error for ClassifyError {}
84
85 pub fn eval_pure(expr: &Expr) -> Result<i64, EvalError> {
86 eval_with_symbols(expr, &BTreeMap::new())
87 }
88
89 pub fn eval_with_symbols(expr: &Expr, symbols: &BTreeMap<String, i64>) -> Result<i64, EvalError> {
90 match expr {
91 Expr::Int(value) => Ok(*value),
92 Expr::Symbol(symbol) => symbols
93 .get(symbol)
94 .copied()
95 .ok_or_else(|| EvalError::UndefinedSymbol(symbol.clone())),
96 Expr::ModifiedSymbol { symbol, .. } => {
97 Err(EvalError::UndefinedSymbol(format!("{}@GOT", symbol)))
98 }
99 Expr::CurrentLocation => Err(EvalError::UndefinedSymbol(".".into())),
100 Expr::UnaryMinus(inner) => eval_with_symbols(inner, symbols)?
101 .checked_neg()
102 .ok_or(EvalError::Overflow),
103 Expr::Add(lhs, rhs) => eval_with_symbols(lhs, symbols)?
104 .checked_add(eval_with_symbols(rhs, symbols)?)
105 .ok_or(EvalError::Overflow),
106 Expr::Sub(lhs, rhs) => eval_with_symbols(lhs, symbols)?
107 .checked_sub(eval_with_symbols(rhs, symbols)?)
108 .ok_or(EvalError::Overflow),
109 }
110 }
111
112 pub fn referenced_symbols(expr: &Expr) -> Vec<String> {
113 let mut symbols = Vec::new();
114 collect_symbols(expr, &mut symbols);
115 symbols
116 }
117
118 pub fn classify(
119 expr: &Expr,
120 symbols: &BTreeMap<String, SymbolValue>,
121 ) -> Result<ClassifiedExpr, ClassifyError> {
122 let mut constant = 0i64;
123 let mut terms = Vec::new();
124 linearize(expr, 1, &mut constant, &mut terms)?;
125 terms.retain(|(_, coeff)| *coeff != 0);
126
127 let mut defined_groups: BTreeMap<usize, Vec<(String, i64, i32)>> = BTreeMap::new();
128 let mut defined_order = Vec::new();
129 let mut remaining = Vec::new();
130 let mut got_terms = Vec::new();
131 let mut current_location_coeff = 0;
132
133 for (term, coeff) in terms {
134 match term {
135 Term::Plain(symbol) => match symbols
136 .get(&symbol)
137 .copied()
138 .unwrap_or(SymbolValue::Undefined)
139 {
140 SymbolValue::Absolute(value) => {
141 constant = checked_add(constant, checked_mul(value, coeff as i64)?)?;
142 }
143 SymbolValue::Defined { section, value } => {
144 if !defined_groups.contains_key(&section) {
145 defined_order.push(section);
146 }
147 defined_groups
148 .entry(section)
149 .or_default()
150 .push((symbol, value, coeff));
151 }
152 SymbolValue::Undefined => push_plain_term(&mut remaining, symbol, coeff),
153 },
154 Term::Modified {
155 symbol,
156 modifier: SymbolModifier::Got,
157 } => {
158 push_plain_term(&mut got_terms, symbol, coeff);
159 }
160 Term::CurrentLocation => {
161 current_location_coeff += coeff;
162 }
163 }
164 }
165
166 for section in defined_order {
167 let group = defined_groups
168 .remove(&section)
169 .expect("section group should exist");
170 let section_sum: i32 = group.iter().map(|(_, _, coeff)| *coeff).sum();
171 match section_sum {
172 0 => {
173 for (_, value, coeff) in group {
174 constant = checked_add(constant, checked_mul(value, coeff as i64)?)?;
175 }
176 }
177 1 | -1 => {
178 let (anchor_symbol, anchor_value, _) = &group[0];
179 for (_, value, coeff) in &group {
180 constant = checked_add(
181 constant,
182 checked_mul(checked_sub(*value, *anchor_value)?, *coeff as i64)?,
183 )?;
184 }
185 push_plain_term(&mut remaining, anchor_symbol.clone(), section_sum);
186 }
187 other => {
188 return Err(ClassifyError::Illegal(format!(
189 "expression has unsupported section-relative coefficient {}",
190 other
191 )));
192 }
193 }
194 }
195
196 remaining.retain(|(_, coeff)| *coeff != 0);
197 got_terms.retain(|(_, coeff)| *coeff != 0);
198
199 if !got_terms.is_empty() || current_location_coeff != 0 {
200 if !remaining.is_empty() {
201 return Err(ClassifyError::Illegal(
202 "pointer-to-GOT expression cannot be combined with plain relocatable symbols"
203 .into(),
204 ));
205 }
206 if got_terms.len() != 1 || got_terms[0].1 != 1 {
207 return Err(ClassifyError::Illegal(
208 "expression is not representable as a pointer-to-GOT relocation".into(),
209 ));
210 }
211 if current_location_coeff != 0 && current_location_coeff != -1 {
212 return Err(ClassifyError::Illegal(
213 "pointer-to-GOT expression may subtract current location only once".into(),
214 ));
215 }
216 return Ok(ClassifiedExpr::PointerToGot {
217 symbol: got_terms[0].0.clone(),
218 addend: constant,
219 pcrel: current_location_coeff == -1,
220 });
221 }
222
223 match remaining.as_slice() {
224 [] => Ok(ClassifiedExpr::Absolute(constant)),
225 [(symbol, 1)] => Ok(ClassifiedExpr::Relocatable {
226 symbol: symbol.clone(),
227 addend: constant,
228 }),
229 [(minuend, 1), (subtrahend, -1)] => Ok(ClassifiedExpr::Difference {
230 minuend: minuend.clone(),
231 subtrahend: subtrahend.clone(),
232 addend: constant,
233 }),
234 [(subtrahend, -1), (minuend, 1)] => Ok(ClassifiedExpr::Difference {
235 minuend: minuend.clone(),
236 subtrahend: subtrahend.clone(),
237 addend: constant,
238 }),
239 _ => Err(ClassifyError::Illegal(
240 "expression is not representable as an absolute value or relocation".into(),
241 )),
242 }
243 }
244
245 fn collect_symbols(expr: &Expr, out: &mut Vec<String>) {
246 match expr {
247 Expr::Int(_) => {}
248 Expr::Symbol(symbol) => out.push(symbol.clone()),
249 Expr::ModifiedSymbol { symbol, .. } => out.push(symbol.clone()),
250 Expr::CurrentLocation => {}
251 Expr::UnaryMinus(inner) => collect_symbols(inner, out),
252 Expr::Add(lhs, rhs) | Expr::Sub(lhs, rhs) => {
253 collect_symbols(lhs, out);
254 collect_symbols(rhs, out);
255 }
256 }
257 }
258
259 fn linearize(
260 expr: &Expr,
261 sign: i32,
262 constant: &mut i64,
263 terms: &mut Vec<(Term, i32)>,
264 ) -> Result<(), ClassifyError> {
265 match expr {
266 Expr::Int(value) => {
267 let signed = if sign == 1 {
268 *value
269 } else {
270 value.checked_neg().ok_or(ClassifyError::Overflow)?
271 };
272 *constant = checked_add(*constant, signed)?;
273 Ok(())
274 }
275 Expr::Symbol(symbol) => {
276 push_term(terms, Term::Plain(symbol.clone()), sign);
277 Ok(())
278 }
279 Expr::ModifiedSymbol { symbol, modifier } => {
280 push_term(
281 terms,
282 Term::Modified {
283 symbol: symbol.clone(),
284 modifier: *modifier,
285 },
286 sign,
287 );
288 Ok(())
289 }
290 Expr::CurrentLocation => {
291 push_term(terms, Term::CurrentLocation, sign);
292 Ok(())
293 }
294 Expr::UnaryMinus(inner) => linearize(inner, -sign, constant, terms),
295 Expr::Add(lhs, rhs) => {
296 linearize(lhs, sign, constant, terms)?;
297 linearize(rhs, sign, constant, terms)
298 }
299 Expr::Sub(lhs, rhs) => {
300 linearize(lhs, sign, constant, terms)?;
301 linearize(rhs, -sign, constant, terms)
302 }
303 }
304 }
305
/// Symbolic leaf of a linearized expression; `push_term` compares terms with `==` to merge
/// repeated occurrences.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Term {
    /// Plain symbol reference by name.
    Plain(String),
    /// Symbol reference with a modifier.
    Modified {
        /// Name of the referenced symbol.
        symbol: String,
        /// Modifier applied to the reference.
        modifier: SymbolModifier,
    },
    /// The current location.
    CurrentLocation,
}
315
316 fn push_term(terms: &mut Vec<(Term, i32)>, term: Term, delta: i32) {
317 if let Some((_, coeff)) = terms.iter_mut().find(|(existing, _)| *existing == term) {
318 *coeff += delta;
319 } else {
320 terms.push((term, delta));
321 }
322 }
323
/// Adds `delta` to the coefficient recorded for `symbol`, appending a new entry when the
/// symbol is not yet present in `terms`.
fn push_plain_term(terms: &mut Vec<(String, i32)>, symbol: String, delta: i32) {
    for (name, coeff) in terms.iter_mut() {
        if *name == symbol {
            *coeff += delta;
            return;
        }
    }
    terms.push((symbol, delta));
}
331
332 fn checked_add(lhs: i64, rhs: i64) -> Result<i64, ClassifyError> {
333 lhs.checked_add(rhs).ok_or(ClassifyError::Overflow)
334 }
335
336 fn checked_sub(lhs: i64, rhs: i64) -> Result<i64, ClassifyError> {
337 lhs.checked_sub(rhs).ok_or(ClassifyError::Overflow)
338 }
339
340 fn checked_mul(lhs: i64, rhs: i64) -> Result<i64, ClassifyError> {
341 lhs.checked_mul(rhs).ok_or(ClassifyError::Overflow)
342 }
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Boxed integer literal.
    fn int(value: i64) -> Box<Expr> {
        Box::new(Expr::Int(value))
    }

    /// Boxed plain symbol reference.
    fn sym(name: &str) -> Box<Expr> {
        Box::new(Expr::Symbol(name.to_string()))
    }

    /// `name@GOT` reference.
    fn got(name: &str) -> Expr {
        Expr::ModifiedSymbol {
            symbol: name.to_string(),
            modifier: SymbolModifier::Got,
        }
    }

    /// Symbol table in which every listed symbol is defined in section 1.
    fn section_one(entries: &[(&str, i64)]) -> BTreeMap<String, SymbolValue> {
        entries
            .iter()
            .map(|&(name, value)| {
                (
                    name.to_string(),
                    SymbolValue::Defined { section: 1, value },
                )
            })
            .collect()
    }

    #[test]
    fn eval_pure_constant_expression() {
        // (1 + 2) - 3
        let sum = Expr::Add(int(1), int(2));
        let expr = Expr::Sub(Box::new(sum), int(3));
        assert_eq!(eval_pure(&expr), Ok(0));
    }

    #[test]
    fn eval_symbolic_expression() {
        // foo + 4 with foo = 8
        let expr = Expr::Add(sym("foo"), int(4));
        let mut symbols = BTreeMap::new();
        symbols.insert(String::from("foo"), 8);
        assert_eq!(eval_with_symbols(&expr, &symbols), Ok(12));
    }

    #[test]
    fn eval_undefined_symbol_errors() {
        let expr = Expr::Symbol("foo".into());
        assert_eq!(
            eval_pure(&expr),
            Err(EvalError::UndefinedSymbol("foo".into()))
        );
    }

    #[test]
    fn classify_relocatable_symbol_plus_constant() {
        // foo + 4 with foo defined in a section
        let expr = Expr::Add(sym("foo"), int(4));
        let symbols = section_one(&[("foo", 12)]);
        assert_eq!(
            classify(&expr, &symbols),
            Ok(ClassifiedExpr::Relocatable {
                symbol: "foo".into(),
                addend: 4,
            })
        );
    }

    #[test]
    fn classify_same_section_difference_as_absolute() {
        // foo - bar, both in section 1
        let expr = Expr::Sub(sym("foo"), sym("bar"));
        let symbols = section_one(&[("foo", 16), ("bar", 24)]);
        assert_eq!(classify(&expr, &symbols), Ok(ClassifiedExpr::Absolute(-8)));
    }

    #[test]
    fn classify_external_difference() {
        // (_foo - _bar) + 4 with both symbols unknown
        let difference = Expr::Sub(sym("_foo"), sym("_bar"));
        let expr = Expr::Add(Box::new(difference), int(4));
        let symbols = BTreeMap::new();
        assert_eq!(
            classify(&expr, &symbols),
            Ok(ClassifiedExpr::Difference {
                minuend: "_foo".into(),
                subtrahend: "_bar".into(),
                addend: 4,
            })
        );
    }

    #[test]
    fn classify_section_sum_one_folds_difference_into_addend() {
        // foo + (bar - baz), all in section 1
        let difference = Expr::Sub(sym("bar"), sym("baz"));
        let expr = Expr::Add(sym("foo"), Box::new(difference));
        let symbols = section_one(&[("foo", 16), ("bar", 24), ("baz", 20)]);
        assert_eq!(
            classify(&expr, &symbols),
            Ok(ClassifiedExpr::Relocatable {
                symbol: "foo".into(),
                addend: 4,
            })
        );
    }

    #[test]
    fn classify_pointer_to_got() {
        let expr = got("_puts");
        let symbols = BTreeMap::new();
        assert_eq!(
            classify(&expr, &symbols),
            Ok(ClassifiedExpr::PointerToGot {
                symbol: "_puts".into(),
                addend: 0,
                pcrel: false,
            })
        );
    }

    #[test]
    fn classify_pointer_to_got_pcrel() {
        // _puts@GOT - .
        let expr = Expr::Sub(Box::new(got("_puts")), Box::new(Expr::CurrentLocation));
        let symbols = BTreeMap::new();
        assert_eq!(
            classify(&expr, &symbols),
            Ok(ClassifiedExpr::PointerToGot {
                symbol: "_puts".into(),
                addend: 0,
                pcrel: true,
            })
        );
    }
}
515