1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204
| import os, re from decimal import Decimal, InvalidOperation from itertools import chain
# Number of columns a stock table must have: pretreat() rejects a header or a
# data row whose field count differs from this value.
TABLE_COLUMN = 13
sep = re.compile(r'\s+') def _sep(s): s = sep.sub(' ', s).strip() return s.split(' ')
def _path(name, root=None): return os.path.join(root or os.getcwd(),name)
def _getStrIncreasedWidth(s): ''' 大致计算以空格 为单位时字串增加的宽度 ''' return sum(ord(c) > 127 for c in s)
def _sort(s): try: return Decimal(s) except InvalidOperation: return s def _key(key): def wrapKey(*args, **kw): return _sort(key(*args, **kw)) return wrapKey
class Table(list):
    """A list of rows with a header and per-column display widths.

    Each row is a list of strings. Besides the rows themselves the table
    keeps:

    * ``structure`` -- the list of column names (header row);
    * ``column`` -- the number of columns;
    * ``column_width`` -- for each column, the widest cell seen so far,
      computed as ``len(cell) + _getStrIncreasedWidth(cell)``;
    * ``equivalent_field`` -- a mapping from alias to real column name.
    """

    def __init__(self, structure, equivalent_field=None):
        """Create an empty table.

        Args:
            structure: iterable of column names.
            equivalent_field: optional mapping of alias -> column name.
        """
        self.structure = list(structure)
        self.column = len(self.structure)
        self.column_width = [self._cellWidth(name) for name in self.structure]
        # Copy the mapping: tables derived through filter()/filterField()/
        # slicing receive this table's aliases, and must not mutate them (or
        # the caller's dict) when addEquivalentField() is called later.
        self.equivalent_field = dict(equivalent_field) if equivalent_field else {}

    @staticmethod
    def _cellWidth(text):
        """Return the width of ``text`` as tracked in ``column_width``."""
        return len(text) + _getStrIncreasedWidth(text)

    def addLine(self, line):
        """Append one row and widen ``column_width`` where needed."""
        row = list(line)
        self.append(row)
        # Measure the stored row, not the argument, so that one-shot
        # iterables (generators) are supported as well.
        for index in range(self.column):
            self.column_width[index] = max(self.column_width[index],
                                           self._cellWidth(row[index]))

    def addLines(self, data):
        """Append several rows and widen ``column_width`` where needed."""
        # Materialise once: ``data`` may be a generator, and it is needed
        # both for extending the table and for measuring the cells.
        rows = [list(line) for line in data]
        if not rows:
            # Nothing to add; max() over an empty sequence would raise.
            return
        self.extend(rows)
        for index in range(self.column):
            widest = max(self._cellWidth(row[index]) for row in rows)
            self.column_width[index] = max(self.column_width[index], widest)

    def getPrint(self):
        """Return the header and all rows as tab-separated, centred text."""
        def formatLine(line):
            # Characters counted by _getStrIncreasedWidth are subtracted from
            # the padding width so the columns stay aligned.
            items = (
                f'{item:^{self.column_width[i] - _getStrIncreasedWidth(item)}}'
                for i, item in enumerate(line)
            )
            return '\t'.join(items)
        return '\n'.join(map(formatLine, chain([self.structure], self)))

    __repr__ = __str__ = getPrint

    def getColumn(self, index):
        """Return the cells of column ``index`` as a list."""
        return [line[index] for line in self]

    def addEquivalentField(self, equivalent_field):
        """Register additional alias -> column name pairs."""
        self.equivalent_field.update(equivalent_field)

    def getFiledName(self, field):
        """Resolve an alias to its column name; unknown names pass through."""
        return self.equivalent_field.get(field, field)

    def indexField(self, key):
        """Return the indices of the matching columns.

        ``key`` is either a predicate applied to each column name, or a
        column name / alias that is compared for equality after alias
        resolution. The result is a (possibly empty) list of indices.
        """
        if callable(key):
            return [i for i, _field in enumerate(self.structure) if key(_field)]
        field = self.getFiledName(key)
        return [i for i, _field in enumerate(self.structure) if field == _field]

    def filter(self, index, key):
        """Return a new table holding the rows for which ``key(row[index])`` is true."""
        table = Table(self.structure, self.equivalent_field)
        for line in self:
            if key(line[index]):
                table.addLine(line)
        return table

    def filterField(self, indexList):
        """Return a new table holding only the columns listed in ``indexList``."""
        table = Table([self.structure[index] for index in indexList],
                      self.equivalent_field)
        for line in self:
            table.addLine([line[index] for index in indexList])
        return table

    def addColumn(self, field_name, key, index=None):
        """Insert a computed column in place.

        Args:
            field_name: name of the new column.
            key: called as ``key(table, row_number, index)`` for every row;
                must return the new cell as a string.
            index: position of the new column; appended when ``None``.
        """
        # Compare against None explicitly: ``not index`` also matched 0 and
        # made it impossible to insert a column at the front.
        if index is None:
            index = self.column
        self.structure.insert(index, field_name)
        self.column_width.insert(index, self._cellWidth(field_name))
        self.column += 1
        for x, line in enumerate(self):
            item = key(self, x, index)
            line.insert(index, item)
            self.column_width[index] = max(self.column_width[index],
                                           self._cellWidth(item))

    def sort(self, key, reverse=False):
        """Sort the rows in place.

        The key is first wrapped with ``_key`` so that values parsing as
        ``Decimal`` compare numerically. If that raises ``TypeError`` (for
        example when ``Decimal`` and ``str`` keys get compared), the sort is
        retried with the plain ``key``.
        """
        try:
            list.sort(self, key=_key(key), reverse=reverse)
        except TypeError:
            list.sort(self, key=key, reverse=reverse)

    def __getitem__(self, n):
        """Return row ``n`` for an int, or a new ``Table`` for a slice."""
        if isinstance(n, int):
            return list.__getitem__(self, n)
        table = Table(self.structure, self.equivalent_field)
        table.addLines(list.__getitem__(self, n))
        return table
def pretreat():
    """Pre-process ``stocks.txt`` into ``stockdata.txt``.

    From the raw data file ``stocks.txt`` this removes the arrows and the
    commas, converts ``x%`` to ``x/100``, drops the last column ("股吧") and
    strips the ``+`` in front of numbers. The result is written to
    ``stockdata.txt`` in the same directory.

    Blank lines in ``stocks.txt`` are skipped.

    Returns:
        The parsed ``Table`` on success, or an error-message string when the
        header or a data row does not have the expected format.
    """
    TABLE_STRUCTURE = re.compile(r'([\w/]{2,})[\t ]+')
    # NOTE(review): as written, the stock-code group requires, in sequence,
    # an optional "SH", one of 6/9, an optional "SZ", one of 0/2, an optional
    # "HK", a literal 0 and four digits. If these were meant to be
    # alternative exchange prefixes, the "|" separators are missing --
    # confirm against real stocks.txt rows. The pattern is left unchanged.
    TABLE_DATA = re.compile(r'^((?:(?:SH)?[69](?:SZ)?[02](?:HK)?)0\d{4})\s+(\S+)\s+[↑↓]?(\S+)\s+\+?(\S+)\s+\+?(\S+)%\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+股吧',flags=re.IGNORECASE)
    hundred = Decimal('100')

    def row_error(line_number):
        # Same message for every kind of malformed data row.
        return f'stocks.txt格式有误, 请确认第{line_number}行的数据是否复制完全!'

    with open(_path(r'stocks.txt'), encoding='utf-8') as stocks:
        structure = TABLE_STRUCTURE.findall(stocks.readline())
        if len(structure) != TABLE_COLUMN:
            return 'stocks.txt格式有误, 请确认表头是否复制完全!'
        table = Table(structure)
        # The header is line 1, so data lines are numbered from 2; using the
        # real file line number keeps the message right when blanks are skipped.
        for line_number, raw in enumerate(stocks, start=2):
            if not raw.strip():
                # A blank (e.g. trailing) line is not a corrupt row.
                continue
            matched = TABLE_DATA.match(raw.replace(',', ''))
            if matched is None or len(matched.groups()) != TABLE_COLUMN:
                return row_error(line_number)
            fields = list(matched.groups())
            try:
                # Fifth column is the percentage: store it as a fraction.
                fields[4] = str(Decimal(fields[4]) / hundred)
            except InvalidOperation:
                # A non-numeric percentage is a format error, not a crash.
                return row_error(line_number)
            table.addLine(fields)
    with open(_path(r'stockdata.txt'), 'w', encoding='utf-8') as stockdata:
        stockdata.write(table.getPrint())
    return table
def readTable():
    """Load the pre-processed ``stockdata.txt`` file into a ``Table``.

    The first line supplies the column names; every following line becomes
    one row. Lines are split into fields with ``_sep``.
    """
    source = _path(r'stockdata.txt')
    with open(source, encoding='utf-8') as stockdata:
        header = _sep(stockdata.readline())
        table = Table(header)
        for row in map(_sep, stockdata):
            table.addLine(row)
    return table
def search(table):
    """Interactively select, project and sort the stocks in ``table``.

    The user is prompted, in this order, for:

    1. stock names S, separated by spaces; when nothing is entered, or when
       no row matches, all stocks are kept;
    2. field names T, separated by spaces; defaults to "代码";
    3. the field ``name`` to sort by, which must be in T; defaults to "代码";
    4. the sort order ("升序" ascending / "降序" descending); anything other
       than "降序" sorts ascending.

    The fields "代码" and "名称" are always shown, whether or not they were
    entered in T.

    Returns:
        The resulting ``Table``, or an error-message string when the sort
        field is not in T or does not identify exactly one column.
    """
    stock_names = input('请输入以空格隔开的若干个股票名称S:').strip()
    if stock_names:
        wanted = _sep(stock_names)
        name_column = table.indexField('名称')[0]
        matched = table.filter(name_column, lambda cell: cell in wanted)
        # Only narrow the table when at least one stock matched.
        if matched:
            table = matched

    raw_fields = input('请输入以空格隔开的若干个字段名T:').strip()
    if not raw_fields:
        raw_fields = '代码'
    T = {table.getFiledName(field) for field in _sep(raw_fields)}
    T.update({'代码', '名称'})
    shown_columns = table.indexField(lambda column: column in T)
    table = table.filterField(shown_columns)

    name = input('请输入一个要排序的字段名name(该字段必须在T中):').strip()
    if not name:
        name = '代码'
    name = table.getFiledName(name)
    if name not in T:
        return f'排序的字段 {name} 必须在T中!'
    candidates = table.indexField(name)
    if len(candidates) != 1:
        return 'name 和 T 输入错误!'
    sort_index = candidates[0]

    order = input('请输入排序方式(升序、降序):').strip()
    descending = order == '降序'
    table.sort(key=lambda row: row[sort_index], reverse=descending)
    return table
def volatilityAnalysis(table):
    """Rank the stocks by volatility, highest first.

    Volatility is computed from the "最高" (high) and "最低" (low) fields as
    ``(high - low) / low``. Rows whose low is zero, non-finite or not a
    number are left out, since no volatility can be computed for them.

    Args:
        table: the stock ``Table``; it is not modified.

    Returns:
        A new ``Table`` with the columns "代码", "名称" and "波动率", sorted
        by volatility in descending order. The whole ranking is returned;
        taking the top three rows is left to the caller.
    """
    def _has_usable_low(cell):
        # The original passed ``eval`` as the predicate here. Evaluating cell
        # text read from a data file executes arbitrary code and crashes on
        # non-numeric cells, so the value is parsed as a Decimal instead.
        try:
            low = Decimal(cell)
        except InvalidOperation:
            return False
        return low.is_finite() and low != 0

    # filter() builds a new table, so adding a column below leaves the
    # caller's table untouched.
    table = table.filter(table.indexField('最低')[0], _has_usable_low)
    index_sup = table.indexField('最高')[0]
    index_inf = table.indexField('最低')[0]

    def _analysis(table, x, y):
        # Called by addColumn() with the table, the row number and the
        # insertion index (unused here).
        sup = Decimal(table[x][index_sup])
        inf = Decimal(table[x][index_inf])
        return f'{(sup-inf)/inf:.5}'

    table.addColumn('波动率', _analysis)
    T = {'代码', '名称', '波动率'}
    table = table.filterField(table.indexField(lambda x: x in T))
    sort_index = table.indexField('波动率')[0]
    table.sort(key=lambda x: x[sort_index], reverse=True)
    return table
# Script entry: reuse the pre-processed stockdata.txt when it exists,
# otherwise build it from stocks.txt.
if os.path.exists(_path(r'stockdata.txt')):
    table = readTable()
else:
    table = pretreat()

# pretreat() returns an error-message string instead of a Table when
# stocks.txt is malformed. Raise explicitly rather than using ``assert``,
# which is stripped when Python runs with -O.
if not isinstance(table, Table):
    raise AssertionError(table)

# Let users type the short field names for these two columns.
table.addEquivalentField({'成交量':'成交量/手', '成交额':'成交额/万'})

print('\n搜索结果如下:', search(table), sep='\n')
|