Extended Berkeley Packet Filter (eBPF) assembler and virtual machine
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

285 lines
6.8KB

  1. package ebpf
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "log"
  7. "math"
  8. "unsafe"
  9. )
  10. var (
  11. ErrNoProgram = errors.New("ebpf: no program is loaded")
  12. errDivisionByZero = errors.New("division by zero")
  13. )
  14. const (
  15. // stackSize (in bytes)
  16. stackSize = 512
  17. //stackAddr = math.MaxUint16 - stackSize - 1
  18. //memoryAddr = math.MaxUint16 + 1
  19. // maxInstructions is the maximum number of instructions the VM accepts.
  20. maxInstructions = 4096
  21. )
  22. type Func func(r1, r2, r3, r4, r5 uint64) uint64
  23. type VM struct {
  24. // Register memory.
  25. Register [11]uint64
  26. // Stack memory.
  27. Stack [stackSize]byte
  28. // PC is the program counter.
  29. PC uint32
  30. // Tracer provides callbacks for the VM.
  31. Tracer
  32. memory []byte
  33. memoryAddr uint64
  34. memorySize uint64
  35. stackAddr uint64
  36. program Program
  37. funcs map[Call]Func
  38. }
  39. func New() *VM {
  40. return &VM{
  41. funcs: make(map[Call]Func),
  42. }
  43. }
  44. func (vm *VM) Load(program Program) error {
  45. if l := len(program); l > maxInstructions {
  46. return fmt.Errorf("ebpf: %d instructions exceeds maximum of %d", l, maxInstructions)
  47. }
  48. if err := program.Verify(); err != nil {
  49. return err
  50. }
  51. vm.program = make(Program, len(program))
  52. copy(vm.program, program)
  53. return nil
  54. }
  55. // Attach a function.
  56. func (vm *VM) Attach(call Call, fn Func) {
  57. vm.funcs[call] = fn
  58. }
  59. // Run the eBPF program that was loaded.
  60. func (vm *VM) Run(memory []byte) (uint64, error) {
  61. if len(vm.program) == 0 {
  62. return math.MaxUint64, ErrNoProgram
  63. }
  64. // Reset register and stack.
  65. vm.reset(memory)
  66. // Trace start.
  67. if vm.Tracer != nil {
  68. vm.TraceStart()
  69. }
  70. execution:
  71. for vm.PC < uint32(len(vm.program)) {
  72. var (
  73. pc = vm.PC
  74. in = vm.program[pc]
  75. err error
  76. )
  77. if vm.Tracer != nil {
  78. vm.TracePre(in)
  79. }
  80. switch in := in.(type) {
  81. case ALUOpConstant:
  82. err = vm.aluOp(false, in.Dst, in.Op, uint64(in.Value))
  83. case ALUOpRegister:
  84. err = vm.aluOp(false, in.Dst, in.Op, uint64(uint32(vm.Register[in.Src])))
  85. case ALU64OpConstant:
  86. err = vm.aluOp(true, in.Dst, in.Op, uint64(in.Value))
  87. case ALU64OpRegister:
  88. err = vm.aluOp(true, in.Dst, in.Op, vm.Register[in.Src])
  89. case Negate:
  90. vm.Register[in.Dst] = uint64(^uint32(vm.Register[in.Dst]))
  91. case Negate64:
  92. vm.Register[in.Dst] = ^vm.Register[in.Dst]
  93. case Call:
  94. err = vm.call(in)
  95. case Exit:
  96. break execution
  97. case Jump:
  98. vm.PC += uint32(in.Offset)
  99. err = vm.checkPC()
  100. case JumpIf:
  101. if in.Cond.Match(vm.Register[in.Dst], uint64(in.Value)) {
  102. vm.PC += uint32(in.Offset)
  103. err = vm.checkPC()
  104. }
  105. case JumpIfX:
  106. if in.Cond.Match(vm.Register[in.Dst], vm.Register[in.Src]) {
  107. vm.PC += uint32(in.Offset)
  108. err = vm.checkPC()
  109. }
  110. case LoadConstant:
  111. vm.Register[in.Dst] = in.Value
  112. case LoadIndirect: // TODO(maze): intended to be used in socket filters, and are therefore not general-purpose
  113. case LoadAbsolute: // TODO(maze): intended to be used in socket filters, and are therefore not general-purpose
  114. case LoadRegister:
  115. vm.Register[in.Dst], err = vm.load(uint64(int64(vm.Register[in.Src])+int64(in.Offset)), in.Size)
  116. case StoreImmediate:
  117. err = vm.store(uint64(int64(vm.Register[in.Dst])+int64(in.Offset)), uint64(in.Value), in.Size)
  118. case StoreRegister:
  119. err = vm.store(uint64(int64(vm.Register[in.Dst])+int64(in.Offset)), vm.Register[in.Src], in.Size)
  120. case ByteSwap:
  121. vm.Register[in.Dst] = in.Swap(vm.Register[in.Dst])
  122. default:
  123. err = fmt.Errorf("unhandled instruction %q", in)
  124. }
  125. if err != nil {
  126. return math.MaxUint64, fmt.Errorf("ebpf: at PC=%#04x: %w", pc, err)
  127. }
  128. if vm.Tracer != nil {
  129. vm.TracePost(in)
  130. }
  131. vm.PC++
  132. }
  133. if vm.Tracer != nil {
  134. vm.TraceEnded()
  135. }
  136. return vm.Register[R0], nil
  137. }
  138. func (vm *VM) aluOp(wide bool, dst Register, op ALUOp, value uint64) error {
  139. switch op {
  140. case ALUOpAdd:
  141. vm.Register[dst] += value
  142. case ALUOpSub:
  143. vm.Register[dst] -= value
  144. case ALUOpMul:
  145. vm.Register[dst] *= value
  146. case ALUOpDiv:
  147. if value == 0 {
  148. return errDivisionByZero
  149. }
  150. vm.Register[dst] /= value
  151. case ALUOpOr:
  152. vm.Register[dst] |= value
  153. case ALUOpAnd:
  154. vm.Register[dst] &= value
  155. case ALUOpShiftLeft:
  156. vm.Register[dst] <<= value
  157. case ALUOpShiftRight:
  158. vm.Register[dst] >>= value
  159. case ALUOpMod:
  160. if value == 0 {
  161. return errDivisionByZero
  162. }
  163. vm.Register[dst] %= value
  164. case ALUOpXor:
  165. vm.Register[dst] ^= value
  166. case ALUOpMove:
  167. vm.Register[dst] = value
  168. case ALUOpArithmicShiftRight:
  169. vm.Register[dst] = uint64(int64(vm.Register[dst]) >> value)
  170. }
  171. if !wide {
  172. vm.Register[dst] &= math.MaxUint32
  173. }
  174. return nil
  175. }
  176. func (vm *VM) call(call Call) error {
  177. if fn, ok := vm.funcs[call]; ok {
  178. vm.Register[R0] = fn(vm.Register[R1], vm.Register[R2], vm.Register[R3], vm.Register[R4], vm.Register[R5])
  179. return nil
  180. }
  181. return fmt.Errorf("no function %#08x registered", uint32(call))
  182. }
  183. func (vm *VM) checkPC() error {
  184. if vm.PC >= uint32(len(vm.program)) {
  185. // TODO is this needed?
  186. }
  187. return nil
  188. }
  189. func (vm *VM) load(addr uint64, size uint8) (out uint64, err error) {
  190. var memory []byte
  191. if memory, addr, err = vm.checkMemory(addr, size, "load"); err != nil {
  192. return
  193. }
  194. switch size {
  195. case 1:
  196. return uint64(memory[addr]), nil
  197. case 2:
  198. return uint64(binary.LittleEndian.Uint16(memory[addr:])), nil
  199. case 4:
  200. return uint64(binary.LittleEndian.Uint32(memory[addr:])), nil
  201. case 8:
  202. return binary.LittleEndian.Uint64(memory[addr:]), nil
  203. default:
  204. panic("unreachable")
  205. }
  206. }
  207. func (vm *VM) store(addr, value uint64, size uint8) (err error) {
  208. log.Printf("store at %#x: %#x (size %d)", addr, value, size)
  209. var memory []byte
  210. if memory, addr, err = vm.checkMemory(addr, size, "store"); err != nil {
  211. return err
  212. }
  213. switch size {
  214. case 1:
  215. memory[addr] = uint8(value)
  216. case 2:
  217. binary.LittleEndian.PutUint16(memory[addr:], uint16(value))
  218. case 4:
  219. binary.LittleEndian.PutUint32(memory[addr:], uint32(value))
  220. case 8:
  221. binary.LittleEndian.PutUint64(memory[addr:], value)
  222. }
  223. return nil
  224. }
  225. func (vm *VM) checkMemory(addr uint64, size uint8, operation string) ([]byte, uint64, error) {
  226. if vm.memoryAddr <= addr && addr+uint64(size) <= vm.memoryAddr+vm.memorySize {
  227. // Memory access
  228. return vm.memory, addr - vm.memoryAddr, nil
  229. }
  230. if vm.stackAddr <= addr && addr+uint64(size) <= vm.stackAddr+stackSize {
  231. // Stack access
  232. return vm.Stack[:], addr - vm.stackAddr, nil
  233. }
  234. return nil, 0, fmt.Errorf("out of bounds memory %s at %#x+%d", operation, addr, size)
  235. }
  236. var (
  237. emptyRegisters [11]uint64
  238. emptyStack [stackSize]byte
  239. )
  240. func (vm *VM) reset(memory []byte) {
  241. copy(vm.Register[:], emptyRegisters[:])
  242. copy(vm.Stack[:], emptyStack[:])
  243. vm.memory = memory
  244. vm.memoryAddr = uint64(uintptr(unsafe.Pointer(&memory[0])))
  245. vm.memorySize = uint64(len(memory))
  246. vm.stackAddr = uint64(uintptr(unsafe.Pointer(&vm.Stack[0])))
  247. vm.Register[R1] = vm.memoryAddr
  248. vm.Register[R10] = vm.stackAddr + stackSize - 1 // end of stack
  249. }
  250. type Tracer interface {
  251. TraceStart()
  252. TracePre(instruction Instruction)
  253. TracePost(instruction Instruction)
  254. TraceEnded()
  255. }