acadia/yunq/parser.py

296 lines
8.7 KiB
Python
Raw Normal View History

from enum import Enum
2023-10-24 12:40:15 -07:00
import sys
class LexemeType(Enum):
    """Kinds of token produced by `lexer`."""

    NONE = 0  # placeholder; the lexer never emits it
    EOF = 1   # sentinel appended after the last real token

    # Identifiers and keywords share one kind; the parser tells keywords
    # apart by the token's value.
    NAME = 2

    # Punctuation.
    LEFT_BRACE = 3   # {
    RIGHT_BRACE = 4  # }
    LEFT_PAREN = 5   # (
    RIGHT_PAREN = 6  # )
    ARROW = 7        # ->
    SEMICOLON = 8    # ;
class Lexeme:
    """A single token: its `LexemeType` plus an optional payload (an identifier's text)."""

    def __init__(self, lextype: LexemeType, value = None):
        self.type = lextype
        self.value = value

    def __str__(self):
        # Falsy payloads (None for punctuation/EOF) are left out of the rendering.
        return "(%s, %s)" % (self.type, self.value) if self.value else "(%s)" % self.type

    def __repr__(self):
        return str(self)
def lexer(program: str):
    """Tokenize *program* into a list of `Lexeme`s terminated by an EOF token.

    Spaces, tabs, carriage returns, newlines and ``//`` line comments are
    skipped. A NAME starts with a letter and continues with letters, digits
    or underscores; ``->`` is the only two-character symbol.

    Calls ``sys.exit`` with a message naming the current line on any input it
    cannot tokenize.
    """
    single_char_tokens = {
        '{': LexemeType.LEFT_BRACE,
        '}': LexemeType.RIGHT_BRACE,
        '(': LexemeType.LEFT_PAREN,
        ')': LexemeType.RIGHT_PAREN,
        ';': LexemeType.SEMICOLON,
    }
    end = len(program)

    def char_at(index: int) -> str:
        # "" past the end of input, so look-ahead never raises IndexError when
        # the program ends mid-token or without a trailing newline.
        return program[index] if index < end else ""

    line = 1
    current = 0
    tokens: list[Lexeme] = []
    while current < end:
        # Scan next token.
        start = current
        curr = program[current]
        if curr == '\n':
            line += 1
        elif curr in (' ', '\t', '\r'):
            pass
        elif curr in single_char_tokens:
            tokens.append(Lexeme(single_char_tokens[curr]))
        elif curr == '-':
            current += 1
            if char_at(current) != '>':
                sys.exit("Expected > after - got '%s' on line %d" % (char_at(current), line))
            tokens.append(Lexeme(LexemeType.ARROW))
        elif curr.isalpha():
            while char_at(current + 1).isalnum() or char_at(current + 1) == '_':
                current += 1
            tokens.append(Lexeme(LexemeType.NAME, program[start:current + 1]))
        elif curr == '/' and char_at(current + 1) == '/':
            # Stop on the last character of the comment, NOT on the newline:
            # the newline must be handled by the next iteration so `line`
            # stays correct for later error messages.
            while char_at(current + 1) not in ('\n', ''):
                current += 1
        else:
            sys.exit("Got unexpected token %s on line %s." % (curr, line))
        current += 1
    tokens.append(Lexeme(LexemeType.EOF))
    return tokens
class Method:
    """One RPC inside an interface: ``method name(request) -> (response);``."""

    def __init__(self, name: str, request: str, response: str):
        # request/response hold message *names*; type_check verifies they exist.
        self.name, self.request, self.response = name, request, response
class Interface:
    """A named interface declaration holding its methods in declaration order."""

    def __init__(self, name: str, methods: list[Method]):
        self.name, self.methods = name, methods
class Type(Enum):
    """Primitive field types usable in message declarations."""

    NONE = 0  # placeholder; no type keyword in type_str_dict maps to it
    U64 = 1
    I64 = 2
    STRING = 3
    BYTES = 4
    CAPABILITY = 5
# Type keyword as written in the IDL source -> field Type.
type_str_dict = dict(
    u64=Type.U64,
    i64=Type.I64,
    string=Type.STRING,
    bytes=Type.BYTES,
    capability=Type.CAPABILITY,
)

# Field Type -> C++ type name used for it (looked up by Field.cpp_type).
type_to_cppstr = {
    Type.U64: "uint64_t",
    Type.I64: "int64_t",
    Type.STRING: "glcr::String",
    Type.CAPABILITY: "z_cap_t",
    Type.BYTES: "glcr::Vector<uint8_t>",
}
class Field:
    """A single field inside a message declaration."""

    def __init__(self, fieldtype: Type, name: str, repeated=False):
        self.type = fieldtype
        self.name = name
        self.repeated = repeated  # True when declared with the "repeated" prefix

    def cpp_type(self):
        """Return the C++ spelling of this field's base type.

        Only ``self.type`` is consulted; ``repeated`` does not affect the result.
        """
        cpp_name = type_to_cppstr[self.type]
        return cpp_name
class Message:
    """A named message declaration holding its fields in declaration order."""

    def __init__(self, name: str, fields: list[Field]):
        self.name, self.fields = name, fields
# A top-level declaration: either an interface or a message.
Decl = Interface | Message

# Module-wide registry: declared name -> its declaration. Written by
# Parser.interface / Parser.message and read by type_check.
name_dict: dict[str, Decl] = dict()
class Parser:
    """Recursive-descent parser over the token list produced by `lexer`.

    Grammar, as implemented by the methods below::

        program   := decl* EOF
        decl      := "message" message | "interface" interface
        interface := NAME "{" method* "}"
        method    := "method" NAME "(" NAME ")" "->" "(" NAME ")" ";"
        message   := NAME "{" field* "}"
        field     := ["repeated"] TYPE NAME ";"   (TYPE is a key of type_str_dict)

    Each parsed interface/message is also registered in the module-level
    ``name_dict``, so names must be unique across every Parser run in the
    process. Any syntax error terminates the process via ``sys.exit``.
    """

    def __init__(self, tokens: list[Lexeme]):
        # Assumes the list ends with an EOF token (as `lexer` produces);
        # parse() relies on it to stop.
        self.tokens = tokens
        self.current = 0  # index of the next token to consume

    def peektype(self) -> LexemeType:
        """Return the type of the next token without consuming it."""
        return self.tokens[self.current].type

    def peekvalue(self) -> str:
        """Return the value of the next token without consuming it."""
        return self.tokens[self.current].value

    def consume(self) -> Lexeme:
        """Return the next token and advance past it."""
        self.current += 1
        return self.tokens[self.current - 1]

    def consume_identifier(self) -> str:
        """Consume a NAME token and return its text; exit on anything else."""
        tok = self.consume()
        if tok.type != LexemeType.NAME:
            sys.exit("Expected identifier got %s" % tok.type)
        return tok.value

    def consume_check(self, lex_type: LexemeType):
        """Consume the next token, exiting unless its type is *lex_type*."""
        tok = self.consume()
        if tok.type != lex_type:
            sys.exit("Expected %s got %s" % (lex_type, tok.type))

    def consume_check_identifier(self, name: str):
        """Consume the next token, exiting unless it is the identifier *name*."""
        tok = self.consume()
        if tok.type != LexemeType.NAME:
            sys.exit("Expected '%s' got a %s" % (name, tok.type))
        if tok.value != name:
            sys.exit("Expected '%s' got '%s'" % (name, tok.value))

    def parse(self) -> list[Decl]:
        """Parse declarations up to the EOF token and return them in order."""
        decls = []
        while self.peektype() != LexemeType.EOF:
            decls.append(self.decl())
        return decls

    def decl(self) -> Decl:
        """Parse one top-level ``message`` or ``interface`` declaration."""
        token = self.consume()
        if token.type != LexemeType.NAME:
            # sys.exit takes a single argument, so the message is formatted
            # first; passing `token` separately raised TypeError instead.
            sys.exit("Unexpected token: %s" % token)
        if token.value == "message":
            return self.message()
        if token.value == "interface":
            return self.interface()
        sys.exit("Unexpected identifier '%s', expected message or interface" % token.value)

    def _check_new_name(self, name: str):
        """Exit if *name* is already registered in the module-wide name_dict."""
        if name in name_dict:
            sys.exit("Name '%s' already exists." % name)

    def interface(self):
        """Parse ``NAME { method* }`` after the ``interface`` keyword and register it."""
        # "interface" consumed by decl.
        name = self.consume_identifier()
        self._check_new_name(name)
        self.consume_check(LexemeType.LEFT_BRACE)
        methods: list[Method] = []
        method_names = set()
        while self.peektype() != LexemeType.RIGHT_BRACE:
            m = self.method()
            if m.name in method_names:
                sys.exit("Method %s declared twice on %s" % (m.name, name))
            method_names.add(m.name)
            methods.append(m)
        self.consume_check(LexemeType.RIGHT_BRACE)
        i = Interface(name, methods)
        name_dict[name] = i
        return i

    def method(self):
        """Parse ``method NAME ( NAME ) -> ( NAME ) ;`` into a Method."""
        self.consume_check_identifier("method")
        name = self.consume_identifier()
        self.consume_check(LexemeType.LEFT_PAREN)
        request = self.consume_identifier()
        self.consume_check(LexemeType.RIGHT_PAREN)
        self.consume_check(LexemeType.ARROW)
        self.consume_check(LexemeType.LEFT_PAREN)
        response = self.consume_identifier()
        self.consume_check(LexemeType.RIGHT_PAREN)
        self.consume_check(LexemeType.SEMICOLON)
        return Method(name, request, response)

    def message(self):
        """Parse ``NAME { field* }`` after the ``message`` keyword and register it."""
        # "message" consumed by decl.
        name = self.consume_identifier()
        self._check_new_name(name)
        self.consume_check(LexemeType.LEFT_BRACE)
        fields: list[Field] = []
        field_names = set()
        while self.peektype() != LexemeType.RIGHT_BRACE:
            f = self.field()
            if f.name in field_names:
                sys.exit("Field %s declared twice on %s" % (f.name, name))
            field_names.add(f.name)
            fields.append(f)
        self.consume_check(LexemeType.RIGHT_BRACE)
        m = Message(name, fields)
        name_dict[name] = m
        return m

    def field(self):
        """Parse ``["repeated"] TYPE NAME ;`` into a Field."""
        repeated = False
        field_type_str = self.consume_identifier()
        if field_type_str == "repeated":
            repeated = True
            field_type_str = self.consume_identifier()
        if field_type_str not in type_str_dict:
            sys.exit("Expected type got '%s'" % field_type_str)
        field_type = type_str_dict[field_type_str]
        name = self.consume_identifier()
        self.consume_check(LexemeType.SEMICOLON)
        return Field(field_type, name, repeated)
def type_check(decls: list[Decl]):
    """Verify every interface method's request and response name a declared message.

    Looks names up in the module-level ``name_dict`` filled by the parser, and
    calls ``sys.exit`` on the first violation. Per method, the request is
    checked before the response, and existence before kind.
    """
    for decl in decls:
        if not isinstance(decl, Interface):
            continue
        for method in decl.methods:
            for role, type_name in (("Request", method.request), ("Response", method.response)):
                target = name_dict.get(type_name)
                if target is None:
                    sys.exit("%s type '%s' for '%s.%s' does not exist" % (role, type_name, decl.name, method.name))
                if not isinstance(target, Message):
                    sys.exit("%s type '%s' for '%s.%s' should be a message" % (role, type_name, decl.name, method.name))
def print_ast(decls: list[Decl]):
    """Print a human-readable outline of *decls* to stdout (debugging aid).

    Interfaces list methods as ``name (request -> response)``; messages list
    fields as ``[repeated ]TYPE name``.
    """
    for decl in decls:
        if isinstance(decl, Interface):
            print("%s (Interface)" % decl.name)
            for method in decl.methods:
                print("\t%s (%s -> %s)" % (method.name, method.request, method.response))
        elif isinstance(decl, Message):
            print("%s (Message)" % decl.name)
            for field in decl.fields:
                # Show the modifier; otherwise a repeated field prints
                # identically to a scalar one.
                prefix = "repeated " if field.repeated else ""
                print("\t%s%s %s" % (prefix, field.type.name, field.name))
        else:
            print("unknown type")