Let us walk on the 3-isogeny graph
Loading...
Searching...
No Matches
AsmMontRedc.py
Go to the documentation of this file.
1#!/usr/bin/env sage -python
2
3
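# How to call (the generated assembly is written to stdout):
#   sage -python AsmMontRedc.py <prime> <mode> > fp9216.s
# e.g. sage -python AsmMontRedc.py 9215k384 0 > fp9216.s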
6
# The star import stays first so the explicit stdlib imports below win
# over any same-named objects it may export.
from sage.all import *
import math
import sys
9
# Module-level state.
# NOTE(review): "sefOfLimbs" looks like a typo for "setOfLimbs"; the name is
# kept as-is because it is a module-level global that other code may import.
# It is not referenced anywhere else in this file.
sefOfLimbs = set()
# Execution mode; overwritten by main() with int(argv[1]).
printOut = 0
12
def rotate(l, x):
    """Return a copy of sequence ``l`` rotated right by ``x`` positions.

    The shift is reduced modulo ``len(l)``, so shifts larger than the
    sequence (or more negative than ``-len(l)``) wrap around instead of
    silently returning the sequence unrotated. A negative ``x`` rotates
    left. An empty sequence is returned as an (empty) copy.
    """
    if not l:
        return l[:]
    x %= len(l)
    # With x == 0 this is l[0:] + l[:0], i.e. an unrotated copy.
    return l[-x:] + l[:-x]
15
def push():
    """Return the assembly prologue that saves registers on the stack.

    The emitted text is a ``# push`` comment line followed by one ``push``
    instruction per saved register and a trailing blank line. The order
    here is the exact reverse of the order restored by ``pop()``.
    """
    saved = ("rbx", "rbp", "rsi", "r12", "r13", "r14", "r15")
    return "# push\n" + "".join(" push " + reg + "\n" for reg in saved) + "\n"
23
def pop():
    """Return the assembly epilogue that restores registers from the stack.

    The emitted text is a ``# pop`` comment line followed by one ``pop``
    instruction per register and a trailing blank line. Registers are
    restored in the reverse of the order in which ``push()`` saves them.
    """
    restored = ("r15", "r14", "r13", "r12", "rsi", "rbp", "rbx")
    return "# pop\n" + "".join(" pop " + reg + "\n" for reg in restored) + "\n"
31
32
def MontRedcAdd(plimbs, label="a_plus_u_i"):
    """Return assembly for a global multi-limb in-place addition routine.

    The emitted routine adds the 64-bit limbs at ``[rsi]`` into the limbs
    at ``[rdi]``: limb 0 with ``add``, limbs ``1 .. plimbs`` with ``adc``
    inside an assembler ``.rept`` block, and finally the remaining carry is
    added into limb ``plimbs + 1`` of ``[rdi]``. The routine is wrapped in
    the ``push()`` / ``pop()`` register save/restore and ends with ``ret``.

    This file previously defined ``MontRedcAdd`` twice with identical
    bodies that differed only in the emitted label (``u_i_times_p`` first,
    ``a_plus_u_i`` second); the second definition shadowed the first. The
    default ``label`` keeps the behaviour of the effective (second)
    definition, and the first variant is available as
    ``MontRedcAdd(plimbs, label="u_i_times_p")``.

    Args:
        plimbs: repetition count of the ``adc`` loop (number of limbs
            processed after limb 0).
        label: global symbol name of the emitted routine.

    Returns:
        The assembly text as a single string.
    """
    top = str(plimbs + 1)
    parts = [
        ".global " + label + "\n" + label + ":\n",
        push(),
        # rdx aliases the destination, so the sum below is [rdi] += [rsi].
        # (The original emitted "[rdi + 0] ]" with a stray bracket, which
        # is not valid assembly.)
        " lea rdx, [rdi + 0]\n",
        # intro: limb 0 starts the carry chain with a plain add
        " mov rax, [rsi + 0]\n",
        " add rax, [rdx + 0]\n",
        " mov [rdi + 0], rax\n",
        # loop: the assembler unrolls this; mov does not touch the flags,
        # so the carry survives from one adc to the next
        " .set k, 1\n",
        " .rept " + str(plimbs) + "\n",
        " mov rax, [rsi + 8*k]\n",
        " adc rax, [rdx + 8*k]\n",
        " mov [rdi + 8*k], rax\n",
        " .set k, k+1\n",
        " .endr\n",
        # outro: fold the final carry into the next destination limb
        " mov rax, [rdi + 8*" + top + "]\n",
        " adc rax, 0\n",
        " mov [rdi + 8*" + top + "], rax\n",
        pop(),
        " ret\n",
    ]
    return "".join(parts)
125
def PrintMontLoop(plimbs, p_symbol="secsidh_internal_2047k221_p"):
    """Return a fully unrolled mulx/adcx/adox reduction instruction sequence.

    The emitted text is a bare instruction sequence: it has no label, no
    register save/restore and no ``ret``. It loads ``rcx`` with ``rsi`` and
    ``r8`` with the GOT entry of ``p_symbol``. For each ``k`` in
    ``0 .. plimbs-2`` it multiplies limb ``k`` of ``[rcx]`` by every limb
    of ``[r8]`` and accumulates the product back into ``[rcx]`` starting at
    limb ``k``. The last pass (``k == plimbs-1``) does the same but writes
    its result limbs to ``[rdi]`` instead, discarding the lowest one.

    NOTE(review): limb k of the input is used directly as the multiplier;
    presumably this relies on -p^-1 = 1 mod 2^64 for the supported primes -
    confirm.
    NOTE(review): the carry flag left by the last ``adcx`` of each pass is
    never added to r11 before the flags are cleared by the next ``xor`` -
    confirm that this cannot lose a carry.
    NOTE(review): this function was reconstructed from a flattened listing;
    the double separator is placed once after the main loop. Separators are
    assembler comments, so this cannot affect the generated instructions.

    Args:
        plimbs: number of 64-bit limbs of the modulus; must be >= 1.
        p_symbol: name of the global symbol holding the modulus limbs.
            The default is the symbol that was previously hard-coded.

    Returns:
        The assembly text as a single string.

    Raises:
        ValueError: if ``plimbs`` is smaller than 1 (the original emitted
            negative memory offsets in that case).
    """
    if plimbs < 1:
        raise ValueError("plimbs must be a positive integer")

    sep = "##########################\n"
    last = plimbs - 1

    out = [
        "lea rcx, [rsi]\n",
        "mov r8, " + p_symbol + "@GOTPCREL[rip]\n",
    ]

    # Passes 0 .. plimbs-2: accumulate in place in [rcx].
    for k in range(last):
        out.append(sep)
        # xor also clears CF and OF, resetting both carry chains.
        out.append("xor r11, r11\n\n")
        out.append("mov rdx, [rcx + 8*{}]\n".format(k))

        for j in range(plimbs):
            dst = "[rcx + 8*{} + 8*{}]".format(j, k)
            # mulx <high>, <low>, p[j]; the multiplier is the implicit rdx
            out.append("mulx r9, r10, [r8 + 8*{}]\n".format(j))
            # add the previous high word to the new low word (CF chain)
            out.append("adcx r10, r11\n")
            # keep the new high word for the next step
            out.append("mov r11, r9\n")
            # add the accumulator limb (OF chain) and store it back
            out.append("adox r10, " + dst + "\n")
            out.append("mov " + dst + ", r10\n\n")

        top = "[rcx + 8*{} + 8*{}]".format(plimbs, k)
        out.append("adox r11, " + top + "\n")
        out.append("mov " + top + ", r11\n")

    out.append(sep)
    out.append(sep)

    # Final pass (k == plimbs-1): results go to [rdi].
    out.append("xor r11, r11\n\n")
    out.append("mov rdx, [rcx + 8*{}]\n".format(last))
    # Limb 0 of this pass is computed only for its carry; it is not stored.
    out.append("mulx r9, r10, [r8]\n")
    out.append("adcx r10, r11\n")
    out.append("mov r11, r9\n")
    out.append("adox r10, [rcx + 8*{}]\n".format(last))

    out.append(sep)

    for j in range(1, plimbs):
        out.append("mulx r9, r10, [r8 + 8*{}]\n".format(j))
        out.append("adcx r10, r11\n")
        out.append("mov r11, r9\n")
        out.append("adox r10, [rcx + 8*{} + 8*{}]\n".format(j, last))
        # result limb j of the accumulator becomes output limb j-1
        out.append("mov [rdi + 8*{}], r10\n\n".format(j - 1))

    out.append("adox r11, [rcx + 8*{} + 8*{}]\n".format(plimbs, last))
    out.append("mov [rdi + 8*{}], r11\n".format(last))

    return "".join(out)
197
def PrintMult(plimbs):
    """Return assembly for the global routine ``p_times_w``.

    The emitted routine multiplies the ``plimbs`` 64-bit limbs at ``[rsi]``
    by the single word in ``rdx`` (the implicit ``mulx`` operand) and
    stores the ``plimbs + 1`` result limbs at ``[rdi]``. Only two
    registers, r8 and r9, alternate as the running high word.

    Changes relative to the previous version:
      * the carry chain always starts with ``add``; for ``plimbs == 2`` the
        first (and last) step used ``adc`` and so consumed a stale carry
        flag;
      * for ``plimbs == 1`` the high word of the product is now stored;
      * the routine ends with ``ret``, like the other global routines
        generated by this file, instead of falling through.

    Args:
        plimbs: number of 64-bit limbs of the multi-limb operand; >= 1.

    Returns:
        The assembly text as a single string.

    Raises:
        ValueError: if ``plimbs`` is smaller than 1.
    """
    if plimbs < 1:
        raise ValueError("plimbs must be a positive integer")

    regs = ("r8", "r9")
    last = plimbs - 1

    out = [
        ".global p_times_w\np_times_w:\n",
        push(),
        " mulx " + regs[0] + ", rcx, [rsi + 0*8]\n",
        " mov [rdi + 0*8], rcx\n",
    ]

    if plimbs == 1:
        # Single limb: the high word is the second (top) result limb.
        out.append(" mov [rdi + 1*8], " + regs[0] + "\n")

    for j in range(1, plimbs):
        hi = regs[j % 2]         # receives the high word of this product
        acc = regs[(j + 1) % 2]  # holds the high word of the previous one

        out.append(" mulx {}, rax, [rsi + {}*8]\n".format(hi, j))
        # mulx and mov leave the flags alone, so the carry produced here is
        # still valid at the next step; the very first step has no carry-in.
        out.append(" {} {}, rax\n".format("add" if j == 1 else "adc", acc))
        if j == last:
            # fold the final carry into the top word
            out.append(" adc {}, 0\n".format(hi))
        out.append(" mov [rdi + {}*8], {}\n".format(j, acc))
        if j == last:
            out.append(" mov [rdi + {}*8], {}\n".format(j + 1, hi))

    out.append(pop())
    out.append(" ret\n")
    return "".join(out)
242
243#//+++++++++++++++ Main ++++++++++++++++++//
244
def main(argv):
    """Generate the assembly for one prime and print it to stdout.

    Args:
        argv: command-line arguments without the program name:
            ``argv[0]`` is the prime identifier (e.g. ``"2047k221"``) and
            ``argv[1]`` is the execution type as an integer, which is
            stored in the module-level ``printOut``.

    Usage and error messages are written to stderr and the process exits
    with status 1, so that a failed run neither looks successful nor
    leaves its error text inside a redirected assembly file.

    NOTE(review): PrintMontLoop references the modulus symbol
    ``secsidh_internal_2047k221_p`` by default, whichever prime is
    selected here - confirm this is intended for the other primes.
    """
    global printOut

    if len(argv) < 2:
        print(
            "\nplease specify the prime and type of execution [1 = only printing functions out, 0 = print full assembly]\n",
            file=sys.stderr,
        )
        sys.exit(1)

    printOut = int(argv[1])
    prime = argv[0]

    # Bit length of the field-element representation for each prime.
    bit_lengths = {
        "2047k221": 2048,
        "4095k256": 4096,
        "5119k234": 5120,
        "6143k256": 6144,
        "8191k332": 8192,
        "9215k384": 9216,
    }

    length = bit_lengths.get(prime)
    if length is None:
        print("\nError : no prime available for this input\n", file=sys.stderr)
        sys.exit(1)

    header = (
        ".intel_syntax noprefix\n\n"
        ".section .rodata\n\n"
        ".section .text\n\n"
    )
    print(header)

    # Number of 64-bit limbs.
    plimbs = length // 64

    print(PrintMontLoop(plimbs))

    print("\n")
299
# Script entry point: pass the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
PrintMult(plimbs)
PrintMontLoop(plimbs)
MontRedcAdd(plimbs)
end if