optimizing a gate level bcm to the end of the earth and back
at main 1574 lines 71 kB view raw
1""" 2BCD to 7-segment decoder solver using SAT-based exact synthesis. 3 4This module implements a multi-output logic synthesis solver that minimizes 5gate inputs through shared term extraction and SAT/MaxSAT optimization. 6""" 7 8from dataclasses import dataclass, field 9from typing import Optional 10from pysat.formula import WCNF, CNF 11from pysat.examples.rc2 import RC2 12from pysat.solvers import Solver 13 14from .truth_tables import SEGMENT_NAMES, SEGMENT_MINTERMS, DONT_CARES 15from .quine_mccluskey import ( 16 Implicant, 17 quine_mccluskey_multi_output, 18 greedy_cover, 19) 20 21 22@dataclass 23class CostBreakdown: 24 """Detailed cost breakdown for a synthesis result.""" 25 26 and_inputs: int # Inputs to AND gates (multi-literal product terms only) 27 or_inputs: int # Inputs to OR gates (one per term per output) 28 num_and_gates: int # Number of AND gates (multi-literal terms) 29 num_or_gates: int # Number of OR gates (one per output = 7) 30 31 @property 32 def total(self) -> int: 33 """Total gate inputs (AND + OR).""" 34 return self.and_inputs + self.or_inputs 35 36 37@dataclass 38class GateInfo: 39 """Information about a gate in exact synthesis.""" 40 index: int # Gate index (0-based, after inputs) 41 input1: int # First input node index 42 input2: int # Second input node index 43 func: int # 4-bit function code 44 func_name: str # Human-readable function name 45 46 47@dataclass 48class SynthesisResult: 49 """Result of logic synthesis optimization.""" 50 51 cost: int # Total gate inputs (for backward compat, = cost_breakdown.and_inputs) 52 implicants_by_output: dict[str, list[Implicant]] 53 shared_implicants: list[tuple[Implicant, list[str]]] 54 method: str 55 expressions: dict[str, str] = field(default_factory=dict) 56 cost_breakdown: CostBreakdown = None 57 # For exact synthesis: gate-level circuit description 58 gates: list[GateInfo] = None 59 output_map: dict[str, int] = None # segment -> node index 60 61 62class BCDTo7SegmentSolver: 63 """ 64 Multi-output logic synthesis solver for BCD to 7-segment decoders. 65 66 Uses a combination of: 67 1. Quine-McCluskey with greedy cover for baseline 68 2. MaxSAT optimization for minimum-cost covering with sharing 69 3. SAT-based exact synthesis for provably optimal circuits 70 """ 71 72 def __init__(self): 73 self.prime_implicants: list[Implicant] = [] 74 self.minterms = {s: set(SEGMENT_MINTERMS[s]) for s in SEGMENT_NAMES} 75 self.dc_set = set(DONT_CARES) 76 77 def _compute_cost_breakdown( 78 self, 79 selected: list[Implicant], 80 implicants_by_output: dict[str, list[Implicant]] 81 ) -> CostBreakdown: 82 """ 83 Compute detailed cost breakdown for a set of selected implicants. 84 85 Cost model (assuming input complements are free): 86 - AND gate inputs: Only for multi-literal terms (2+ literals) 87 Single literals (A, B', etc.) are direct wires, no AND needed 88 - OR gate inputs: One per term per output it feeds 89 - AND gates: One per multi-literal term (shared across outputs) 90 - OR gates: One per output (7 total) 91 """ 92 and_inputs = 0 93 num_and_gates = 0 94 95 for impl in selected: 96 if impl.num_literals >= 2: 97 # Multi-literal term needs an AND gate 98 and_inputs += impl.num_literals 99 num_and_gates += 1 100 # Single-literal terms are just wires (no AND gate cost) 101 102 # OR inputs: count terms feeding each output 103 or_inputs = sum( 104 len(implicants_by_output[seg]) 105 for seg in SEGMENT_NAMES 106 if seg in implicants_by_output 107 ) 108 109 return CostBreakdown( 110 and_inputs=and_inputs, 111 or_inputs=or_inputs, 112 num_and_gates=num_and_gates, 113 num_or_gates=7, 114 ) 115 116 def greedy_baseline(self) -> SynthesisResult: 117 """ 118 Phase 1: Establish baseline using greedy set cover. 119 120 Returns the baseline cost and selected implicants. 121 """ 122 if not self.prime_implicants: 123 self.generate_prime_implicants() 124 125 selected, cost = greedy_cover(self.prime_implicants, self.minterms) 126 127 # Organize by output 128 implicants_by_output = {s: [] for s in SEGMENT_NAMES} 129 shared = [] 130 131 for impl in selected: 132 outputs_using = list(impl.covered_minterms.keys()) 133 if len(outputs_using) > 1: 134 shared.append((impl, outputs_using)) 135 for out in outputs_using: 136 implicants_by_output[out].append(impl) 137 138 # Build expressions 139 expressions = {} 140 for segment in SEGMENT_NAMES: 141 terms = [impl.to_expr_str() for impl in implicants_by_output[segment]] 142 expressions[segment] = " + ".join(terms) if terms else "0" 143 144 # Compute detailed cost breakdown 145 cost_breakdown = self._compute_cost_breakdown(selected, implicants_by_output) 146 147 return SynthesisResult( 148 cost=cost_breakdown.total, # Total = AND inputs + OR inputs 149 implicants_by_output=implicants_by_output, 150 shared_implicants=shared, 151 method="greedy", 152 expressions=expressions, 153 cost_breakdown=cost_breakdown, 154 ) 155 156 def generate_prime_implicants(self) -> list[Implicant]: 157 """Generate all prime implicants with multi-output coverage tags.""" 158 self.prime_implicants = quine_mccluskey_multi_output( 159 self.minterms, 160 self.dc_set, 161 n_vars=4 162 ) 163 return self.prime_implicants 164 165 def maxsat_optimize(self, target_cost: int = 22) -> SynthesisResult: 166 """ 167 Phase 2: MaxSAT optimization for minimum-cost covering with sharing. 168 169 Formulates the covering problem as weighted MaxSAT where: 170 - Hard clauses: every minterm of every output must be covered 171 - Soft clauses: minimize total literals (penalize each implicant) 172 """ 173 if not self.prime_implicants: 174 self.generate_prime_implicants() 175 176 wcnf = WCNF() 177 178 # Variable mapping: implicant index -> SAT variable (1-indexed) 179 impl_vars = {i: i + 1 for i in range(len(self.prime_implicants))} 180 181 # Hard constraints: every (output, minterm) pair must be covered 182 for segment in SEGMENT_NAMES: 183 for minterm in SEGMENT_MINTERMS[segment]: 184 covering = [] 185 for i, impl in enumerate(self.prime_implicants): 186 if segment in impl.covered_minterms: 187 if minterm in impl.covered_minterms[segment]: 188 covering.append(impl_vars[i]) 189 190 if covering: 191 wcnf.append(covering) # Hard: at least one must be selected 192 else: 193 raise RuntimeError( 194 f"No implicant covers {segment}:{minterm}" 195 ) 196 197 # Soft constraints: penalize each implicant by its total gate input cost 198 # Cost = AND inputs + OR inputs 199 # - AND inputs: num_literals if >= 2, else 0 (single literals are wires) 200 # - OR inputs: one per output this implicant covers 201 for i, impl in enumerate(self.prime_implicants): 202 and_cost = impl.num_literals if impl.num_literals >= 2 else 0 203 or_cost = len(impl.covered_minterms) # Number of outputs it feeds 204 total_cost = and_cost + or_cost 205 if total_cost > 0: 206 wcnf.append([-impl_vars[i]], weight=total_cost) 207 208 # Solve 209 with RC2(wcnf) as solver: 210 model = solver.compute() 211 if model is None: 212 raise RuntimeError("MaxSAT solver found no solution") 213 214 # Extract selected implicants 215 selected = [] 216 for i, impl in enumerate(self.prime_implicants): 217 if impl_vars[i] in model: 218 selected.append(impl) 219 220 # Organize by output 221 implicants_by_output = {s: [] for s in SEGMENT_NAMES} 222 shared = [] 223 224 for impl in selected: 225 outputs_using = list(impl.covered_minterms.keys()) 226 if len(outputs_using) > 1: 227 shared.append((impl, outputs_using)) 228 for out in outputs_using: 229 implicants_by_output[out].append(impl) 230 231 # Build expressions 232 expressions = {} 233 for segment in SEGMENT_NAMES: 234 terms = [impl.to_expr_str() for impl in implicants_by_output[segment]] 235 expressions[segment] = " + ".join(terms) if terms else "0" 236 237 # Compute detailed cost breakdown 238 cost_breakdown = self._compute_cost_breakdown(selected, implicants_by_output) 239 240 return SynthesisResult( 241 cost=cost_breakdown.total, # Total = AND inputs + OR inputs 242 implicants_by_output=implicants_by_output, 243 shared_implicants=shared, 244 method="maxsat", 245 expressions=expressions, 246 cost_breakdown=cost_breakdown, 247 ) 248 249 def exact_synthesis(self, max_gates: int = 15, min_gates: int = 1, use_complements: bool = False) -> SynthesisResult: 250 """ 251 Phase 3: SAT-based exact synthesis for provably optimal circuits. 252 253 Encodes the circuit synthesis problem as SAT and iteratively searches 254 for the minimum number of gates. 255 256 Args: 257 max_gates: Maximum number of gates to try 258 min_gates: Minimum number of gates to start from 259 use_complements: If True, include A',B',C',D' as free inputs 260 """ 261 import sys 262 complement_str = " (with complements)" if use_complements else "" 263 for num_gates in range(min_gates, max_gates + 1): 264 print(f" Trying {num_gates} gates{complement_str}...", flush=True) 265 sys.stdout.flush() 266 result = self._try_exact_synthesis(num_gates, use_complements) 267 if result is not None: 268 return result 269 270 raise RuntimeError(f"No solution found with up to {max_gates} gates") 271 272 def exact_synthesis_mixed(self, max_inputs: int = 24, use_complements: bool = True) -> SynthesisResult: 273 """ 274 SAT-based exact synthesis with mixed 2-input and 3-input gates. 275 276 Searches for circuits with total gate inputs <= max_inputs. 277 """ 278 import sys 279 280 # Try different combinations of 2-input and 3-input gates 281 # Cost = 2*n2 + 3*n3, want to minimize while finding valid circuit 282 best_result = None 283 284 for total_cost in range(14, max_inputs + 1): # Start from reasonable minimum 285 print(f" Trying circuits with {total_cost} total inputs...", flush=True) 286 287 # Try all valid (n2, n3) combinations for this cost 288 for n3 in range(total_cost // 3 + 1): 289 remaining = total_cost - 3 * n3 290 if remaining >= 0 and remaining % 2 == 0: 291 n2 = remaining // 2 292 if n2 + n3 >= 7: # Need at least 7 gates for 7 outputs 293 result = self._try_mixed_synthesis(n2, n3, use_complements) 294 if result is not None: 295 return result 296 297 raise RuntimeError(f"No solution found with up to {max_inputs} gate inputs") 298 299 def _try_mixed_synthesis(self, num_2input: int, num_3input: int, use_complements: bool = True, restrict_functions: bool = True) -> Optional[SynthesisResult]: 300 """Try synthesis with a specific mix of 2-input and 3-input gates.""" 301 n_primary = 4 302 n_inputs = 8 if use_complements else 4 303 n_outputs = 7 304 n_gates = num_2input + num_3input 305 n_nodes = n_inputs + n_gates 306 307 truth_rows = list(range(10)) 308 n_rows = len(truth_rows) 309 310 cnf = CNF() 311 var_counter = [1] 312 313 def new_var(): 314 v = var_counter[0] 315 var_counter[0] += 1 316 return v 317 318 # x[i][t] = output of node i on row t 319 x = {i: {t: new_var() for t in range(n_rows)} for i in range(n_nodes)} 320 321 # For 2-input gates: s2[i][j][k] = gate i uses inputs j, k 322 # For 3-input gates: s3[i][j][k][l] = gate i uses inputs j, k, l 323 s2 = {} 324 s3 = {} 325 f2 = {} # 4-bit function for 2-input gates 326 f3 = {} # 8-bit function for 3-input gates 327 328 # Gate type: is_3input[i] = True if gate i is 3-input 329 is_3input = {} 330 331 # First num_2input gates are 2-input, rest are 3-input 332 for gate_idx in range(n_gates): 333 i = n_inputs + gate_idx 334 if gate_idx < num_2input: 335 # 2-input gate 336 s2[i] = {} 337 for j in range(i): 338 s2[i][j] = {k: new_var() for k in range(j + 1, i)} 339 f2[i] = {p: {q: new_var() for q in range(2)} for p in range(2)} 340 else: 341 # 3-input gate 342 s3[i] = {} 343 for j in range(i): 344 s3[i][j] = {} 345 for k in range(j + 1, i): 346 s3[i][j][k] = {l: new_var() for l in range(k + 1, i)} 347 # 8-bit function table for 3 inputs 348 f3[i] = {p: {q: {r: new_var() for r in range(2)} for q in range(2)} for p in range(2)} 349 350 # g[h][i] = output h comes from node i 351 g = {h: {i: new_var() for i in range(n_nodes)} for h in range(n_outputs)} 352 353 # Constraint 1: Primary inputs fixed by truth table 354 for t_idx, t in enumerate(truth_rows): 355 for i in range(n_primary): 356 bit = (t >> (n_primary - 1 - i)) & 1 357 cnf.append([x[i][t_idx] if bit else -x[i][t_idx]]) 358 if use_complements: 359 for i in range(n_primary): 360 bit = (t >> (n_primary - 1 - i)) & 1 361 cnf.append([x[n_primary + i][t_idx] if not bit else -x[n_primary + i][t_idx]]) 362 363 # Constraint 2: Each gate has exactly one input selection 364 for gate_idx in range(n_gates): 365 i = n_inputs + gate_idx 366 if gate_idx < num_2input: 367 all_sels = [s2[i][j][k] for j in range(i) for k in range(j + 1, i)] 368 else: 369 all_sels = [s3[i][j][k][l] for j in range(i) for k in range(j + 1, i) for l in range(k + 1, i)] 370 371 cnf.append(all_sels) # At least one 372 for idx1, sel1 in enumerate(all_sels): 373 for sel2 in all_sels[idx1 + 1:]: 374 cnf.append([-sel1, -sel2]) # At most one 375 376 # Constraint 3: Gate function consistency 377 for gate_idx in range(n_gates): 378 i = n_inputs + gate_idx 379 if gate_idx < num_2input: 380 # 2-input gate 381 for j in range(i): 382 for k in range(j + 1, i): 383 for t_idx in range(n_rows): 384 for pv in range(2): 385 for qv in range(2): 386 for outv in range(2): 387 clause = [-s2[i][j][k]] 388 clause.append(-x[j][t_idx] if pv else x[j][t_idx]) 389 clause.append(-x[k][t_idx] if qv else x[k][t_idx]) 390 clause.append(-f2[i][pv][qv] if outv else f2[i][pv][qv]) 391 clause.append(x[i][t_idx] if outv else -x[i][t_idx]) 392 cnf.append(clause) 393 else: 394 # 3-input gate 395 for j in range(i): 396 for k in range(j + 1, i): 397 for l in range(k + 1, i): 398 for t_idx in range(n_rows): 399 for pv in range(2): 400 for qv in range(2): 401 for rv in range(2): 402 for outv in range(2): 403 clause = [-s3[i][j][k][l]] 404 clause.append(-x[j][t_idx] if pv else x[j][t_idx]) 405 clause.append(-x[k][t_idx] if qv else x[k][t_idx]) 406 clause.append(-x[l][t_idx] if rv else x[l][t_idx]) 407 clause.append(-f3[i][pv][qv][rv] if outv else f3[i][pv][qv][rv]) 408 clause.append(x[i][t_idx] if outv else -x[i][t_idx]) 409 cnf.append(clause) 410 411 # Constraint 3b: Restrict to standard gate functions 412 if restrict_functions: 413 # 2-input: AND, OR, XOR, XNOR, NAND, NOR 414 allowed_2input = [0b1000, 0b1110, 0b0110, 0b1001, 0b0111, 0b0001] 415 for gate_idx in range(num_2input): 416 i = n_inputs + gate_idx 417 or_clause = [] 418 for func in allowed_2input: 419 match_var = new_var() 420 or_clause.append(match_var) 421 for p in range(2): 422 for q in range(2): 423 bit_idx = p * 2 + q 424 expected = (func >> bit_idx) & 1 425 if expected: 426 cnf.append([-match_var, f2[i][p][q]]) 427 else: 428 cnf.append([-match_var, -f2[i][p][q]]) 429 cnf.append(or_clause) 430 431 # 3-input: AND3, OR3, XOR3, XNOR3, NAND3, NOR3 432 allowed_3input = [ 433 0b10000000, # AND3 434 0b11111110, # OR3 435 0b01111111, # NAND3 436 0b00000001, # NOR3 437 0b10010110, # XOR3 (odd parity) 438 0b01101001, # XNOR3 (even parity) 439 ] 440 for gate_idx in range(num_2input, num_2input + num_3input): 441 i = n_inputs + gate_idx 442 or_clause = [] 443 for func in allowed_3input: 444 match_var = new_var() 445 or_clause.append(match_var) 446 for p in range(2): 447 for q in range(2): 448 for r in range(2): 449 bit_idx = p * 4 + q * 2 + r 450 expected = (func >> bit_idx) & 1 451 if expected: 452 cnf.append([-match_var, f3[i][p][q][r]]) 453 else: 454 cnf.append([-match_var, -f3[i][p][q][r]]) 455 cnf.append(or_clause) 456 457 # Constraint 4: Each output assigned to exactly one node 458 for h in range(n_outputs): 459 cnf.append([g[h][i] for i in range(n_nodes)]) 460 for i in range(n_nodes): 461 for j in range(i + 1, n_nodes): 462 cnf.append([-g[h][i], -g[h][j]]) 463 464 # Constraint 5: Output correctness 465 for h, segment in enumerate(SEGMENT_NAMES): 466 for t_idx, t in enumerate(truth_rows): 467 expected = 1 if t in SEGMENT_MINTERMS[segment] else 0 468 for i in range(n_nodes): 469 if expected: 470 cnf.append([-g[h][i], x[i][t_idx]]) 471 else: 472 cnf.append([-g[h][i], -x[i][t_idx]]) 473 474 # Solve 475 with Solver(bootstrap_with=cnf) as solver: 476 if solver.solve(): 477 model = set(solver.get_model()) 478 return self._decode_mixed_solution( 479 model, num_2input, num_3input, n_inputs, n_nodes, 480 x, s2, s3, f2, f3, g, use_complements 481 ) 482 return None 483 484 def _decode_mixed_solution(self, model, num_2input, num_3input, n_inputs, n_nodes, 485 x, s2, s3, f2, f3, g, use_complements) -> SynthesisResult: 486 """Decode SAT solution for mixed gate sizes.""" 487 def is_true(var): 488 return var in model 489 490 if use_complements: 491 node_names = ['A', 'B', 'C', 'D', "A'", "B'", "C'", "D'"] + [f'g{i}' for i in range(num_2input + num_3input)] 492 else: 493 node_names = ['A', 'B', 'C', 'D'] + [f'g{i}' for i in range(num_2input + num_3input)] 494 495 gates = [] 496 n_gates = num_2input + num_3input 497 498 for gate_idx in range(n_gates): 499 i = n_inputs + gate_idx 500 if gate_idx < num_2input: 501 # 2-input gate 502 for j in range(i): 503 for k in range(j + 1, i): 504 if is_true(s2[i][j][k]): 505 func = 0 506 for p in range(2): 507 for q in range(2): 508 if is_true(f2[i][p][q]): 509 func |= (1 << (p * 2 + q)) 510 func_name = self._decode_gate_function(func) 511 gates.append(GateInfo( 512 index=gate_idx, 513 input1=j, 514 input2=k, 515 func=func, 516 func_name=func_name, 517 )) 518 expr = f"({node_names[j]} {func_name} {node_names[k]})" 519 node_names[i] = expr 520 break 521 else: 522 # 3-input gate 523 for j in range(i): 524 for k in range(j + 1, i): 525 for l in range(k + 1, i): 526 if is_true(s3[i][j][k][l]): 527 func = 0 528 for p in range(2): 529 for q in range(2): 530 for r in range(2): 531 if is_true(f3[i][p][q][r]): 532 func |= (1 << (p * 4 + q * 2 + r)) 533 func_name = self._decode_3input_function(func) 534 # Store as GateInfo with input2 being a tuple indicator 535 gates.append(GateInfo( 536 index=gate_idx, 537 input1=j, 538 input2=(k, l), # Pack two inputs 539 func=func, 540 func_name=func_name, 541 )) 542 expr = f"({node_names[j]} {func_name} {node_names[k]} {node_names[l]})" 543 node_names[i] = expr 544 break 545 546 # Map outputs 547 output_map = {} 548 expressions = {} 549 for h, segment in enumerate(SEGMENT_NAMES): 550 for i in range(n_nodes): 551 if is_true(g[h][i]): 552 output_map[segment] = i 553 expressions[segment] = node_names[i] 554 break 555 556 total_cost = 2 * num_2input + 3 * num_3input 557 cost_breakdown = CostBreakdown( 558 and_inputs=total_cost, 559 or_inputs=0, 560 num_and_gates=num_2input + num_3input, 561 num_or_gates=0, 562 ) 563 564 return SynthesisResult( 565 cost=total_cost, 566 implicants_by_output={}, 567 shared_implicants=[], 568 method=f"exact_mixed_{num_2input}x2_{num_3input}x3", 569 expressions=expressions, 570 cost_breakdown=cost_breakdown, 571 gates=gates, 572 output_map=output_map, 573 ) 574 575 def _decode_3input_function(self, func: int) -> str: 576 """Decode 8-bit function for 3-input gate.""" 577 # Common 3-input functions 578 known = { 579 0b00000001: "NOR3", 580 0b01111111: "NAND3", 581 0b10000000: "AND3", 582 0b11111110: "OR3", 583 0b10010110: "XOR3", # Odd parity 584 0b01101001: "XNOR3", # Even parity 585 0b11101000: "MAJ", # Majority 586 0b00010111: "MIN", # Minority 587 } 588 return known.get(func, f"F3_{func:08b}") 589 590 def _decode_4input_function(self, func: int) -> str: 591 """Decode 16-bit function for 4-input gate.""" 592 known = { 593 0x0001: "NOR4", 594 0x7FFF: "NAND4", 595 0x8000: "AND4", 596 0xFFFE: "OR4", 597 0x6996: "XOR4", # Odd parity 598 0x9669: "XNOR4", # Even parity 599 } 600 return known.get(func, f"F4_{func:016b}") 601 602 def _build_general_cnf(self, num_2input: int, num_3input: int, num_4input: int, 603 use_complements: bool = True, restrict_functions: bool = True) -> Optional[dict]: 604 """Build CNF for general synthesis without solving. Returns CNF + metadata for decoding.""" 605 n_primary = 4 606 n_inputs = 8 if use_complements else 4 607 n_outputs = 7 608 n_gates = num_2input + num_3input + num_4input 609 n_nodes = n_inputs + n_gates 610 611 truth_rows = list(range(10)) 612 n_rows = len(truth_rows) 613 614 clauses = [] 615 var_counter = [1] 616 617 def new_var(): 618 v = var_counter[0] 619 var_counter[0] += 1 620 return v 621 622 # x[i][t] = output of node i on row t 623 x = {i: {t: new_var() for t in range(n_rows)} for i in range(n_nodes)} 624 625 # Selection and function variables 626 s2, s3, s4 = {}, {}, {} 627 f2, f3, f4 = {}, {}, {} 628 629 gate_sizes = [2] * num_2input + [3] * num_3input + [4] * num_4input 630 631 for gate_idx in range(n_gates): 632 i = n_inputs + gate_idx 633 size = gate_sizes[gate_idx] 634 635 if size == 2: 636 s2[i] = {} 637 for j in range(i): 638 s2[i][j] = {k: new_var() for k in range(j + 1, i)} 639 f2[i] = {p: {q: new_var() for q in range(2)} for p in range(2)} 640 elif size == 3: 641 s3[i] = {} 642 for j in range(i): 643 s3[i][j] = {} 644 for k in range(j + 1, i): 645 s3[i][j][k] = {l: new_var() for l in range(k + 1, i)} 646 f3[i] = {p: {q: {r: new_var() for r in range(2)} for q in range(2)} for p in range(2)} 647 else: 648 s4[i] = {} 649 for j in range(i): 650 s4[i][j] = {} 651 for k in range(j + 1, i): 652 s4[i][j][k] = {} 653 for l in range(k + 1, i): 654 s4[i][j][k][l] = {m: new_var() for m in range(l + 1, i)} 655 f4[i] = {p: {q: {r: {s: new_var() for s in range(2)} for r in range(2)} for q in range(2)} for p in range(2)} 656 657 g = {h: {i: new_var() for i in range(n_nodes)} for h in range(n_outputs)} 658 659 # Constraint 1: Primary inputs 660 for t_idx, t in enumerate(truth_rows): 661 for i in range(n_primary): 662 bit = (t >> (n_primary - 1 - i)) & 1 663 clauses.append([x[i][t_idx] if bit else -x[i][t_idx]]) 664 if use_complements: 665 for i in range(n_primary): 666 bit = (t >> (n_primary - 1 - i)) & 1 667 clauses.append([x[n_primary + i][t_idx] if not bit else -x[n_primary + i][t_idx]]) 668 669 # Constraint 2: Each gate has exactly one input selection 670 for gate_idx in range(n_gates): 671 i = n_inputs + gate_idx 672 size = gate_sizes[gate_idx] 673 674 if size == 2: 675 all_sels = [s2[i][j][k] for j in range(i) for k in range(j + 1, i)] 676 elif size == 3: 677 all_sels = [s3[i][j][k][l] for j in range(i) for k in range(j + 1, i) for l in range(k + 1, i)] 678 else: 679 all_sels = [s4[i][j][k][l][m] for j in range(i) for k in range(j + 1, i) for l in range(k + 1, i) for m in range(l + 1, i)] 680 681 if not all_sels: 682 return None 683 684 clauses.append(all_sels) 685 for idx1, sel1 in enumerate(all_sels): 686 for sel2 in all_sels[idx1 + 1:]: 687 clauses.append([-sel1, -sel2]) 688 689 # Constraint 2b: Symmetry breaking - gates of same type ordered by first input 690 # For consecutive gates of the same size, require first input index is non-decreasing 691 for gate_idx in range(n_gates - 1): 692 i = n_inputs + gate_idx 693 i_next = n_inputs + gate_idx + 1 694 size = gate_sizes[gate_idx] 695 size_next = gate_sizes[gate_idx + 1] 696 697 if size != size_next: 698 continue # Only break symmetry between same-type gates 699 700 if size == 2: 701 # For 2-input gates: if gate i has first input j and gate i+1 has first input j', 702 # require j <= j' 703 for j in range(i): 704 for k in range(j + 1, i): 705 for j_next in range(j): # j_next < j violates ordering 706 for k_next in range(j_next + 1, i_next): 707 if j_next in s2[i_next] and k_next in s2[i_next][j_next]: 708 clauses.append([-s2[i][j][k], -s2[i_next][j_next][k_next]]) 709 elif size == 3: 710 for j in range(i): 711 for k in range(j + 1, i): 712 for l in range(k + 1, i): 713 for j_next in range(j): 714 for k_next in range(j_next + 1, i_next): 715 for l_next in range(k_next + 1, i_next): 716 if j_next in s3[i_next] and k_next in s3[i_next][j_next] and l_next in s3[i_next][j_next][k_next]: 717 clauses.append([-s3[i][j][k][l], -s3[i_next][j_next][k_next][l_next]]) 718 else: # size == 4 719 for j in range(i): 720 for k in range(j + 1, i): 721 for l in range(k + 1, i): 722 for m in range(l + 1, i): 723 for j_next in range(j): 724 for k_next in range(j_next + 1, i_next): 725 for l_next in range(k_next + 1, i_next): 726 for m_next in range(l_next + 1, i_next): 727 if (j_next in s4[i_next] and k_next in s4[i_next][j_next] and 728 l_next in s4[i_next][j_next][k_next] and m_next in s4[i_next][j_next][k_next][l_next]): 729 clauses.append([-s4[i][j][k][l][m], -s4[i_next][j_next][k_next][l_next][m_next]]) 730 731 # Constraint 3: Gate function consistency 732 for gate_idx in range(n_gates): 733 i = n_inputs + gate_idx 734 size = gate_sizes[gate_idx] 735 736 if size == 2: 737 for j in range(i): 738 for k in range(j + 1, i): 739 for t_idx in range(n_rows): 740 for pv in range(2): 741 for qv in range(2): 742 for outv in range(2): 743 clause = [-s2[i][j][k]] 744 clause.append(-x[j][t_idx] if pv else x[j][t_idx]) 745 clause.append(-x[k][t_idx] if qv else x[k][t_idx]) 746 clause.append(-f2[i][pv][qv] if outv else f2[i][pv][qv]) 747 clause.append(x[i][t_idx] if outv else -x[i][t_idx]) 748 clauses.append(clause) 749 elif size == 3: 750 for j in range(i): 751 for k in range(j + 1, i): 752 for l in range(k + 1, i): 753 for t_idx in range(n_rows): 754 for pv in range(2): 755 for qv in range(2): 756 for rv in range(2): 757 for outv in range(2): 758 clause = [-s3[i][j][k][l]] 759 clause.append(-x[j][t_idx] if pv else x[j][t_idx]) 760 clause.append(-x[k][t_idx] if qv else x[k][t_idx]) 761 clause.append(-x[l][t_idx] if rv else x[l][t_idx]) 762 clause.append(-f3[i][pv][qv][rv] if outv else f3[i][pv][qv][rv]) 763 clause.append(x[i][t_idx] if outv else -x[i][t_idx]) 764 clauses.append(clause) 765 else: 766 for j in range(i): 767 for k in range(j + 1, i): 768 for l in range(k + 1, i): 769 for m in range(l + 1, i): 770 for t_idx in range(n_rows): 771 for pv in range(2): 772 for qv in range(2): 773 for rv in range(2): 774 for sv in range(2): 775 for outv in range(2): 776 clause = [-s4[i][j][k][l][m]] 777 clause.append(-x[j][t_idx] if pv else x[j][t_idx]) 778 clause.append(-x[k][t_idx] if qv else x[k][t_idx]) 779 clause.append(-x[l][t_idx] if rv else x[l][t_idx]) 780 clause.append(-x[m][t_idx] if sv else x[m][t_idx]) 781 clause.append(-f4[i][pv][qv][rv][sv] if outv else f4[i][pv][qv][rv][sv]) 782 clause.append(x[i][t_idx] if outv else -x[i][t_idx]) 783 clauses.append(clause) 784 785 # Constraint 3b: Restrict functions 786 if restrict_functions: 787 allowed_2input = [0b1000, 0b1110, 0b0110, 0b1001, 0b0111, 0b0001] 788 allowed_3input = [0b10000000, 0b11111110, 0b01111111, 0b00000001, 0b10010110, 0b01101001] 789 allowed_4input = [0x8000, 0xFFFE, 0x7FFF, 0x0001, 0x6996, 0x9669] 790 791 for gate_idx in range(n_gates): 792 i = n_inputs + gate_idx 793 size = gate_sizes[gate_idx] 794 795 if size == 2: 796 or_clause = [] 797 for func in allowed_2input: 798 match_var = new_var() 799 or_clause.append(match_var) 800 for p in range(2): 801 for q in range(2): 802 bit_idx = p * 2 + q 803 expected = (func >> bit_idx) & 1 804 clauses.append([-match_var, f2[i][p][q] if expected else -f2[i][p][q]]) 805 clauses.append(or_clause) 806 elif size == 3: 807 or_clause = [] 808 for func in allowed_3input: 809 match_var = new_var() 810 or_clause.append(match_var) 811 for p in range(2): 812 for q in range(2): 813 for r in range(2): 814 bit_idx = p * 4 + q * 2 + r 815 expected = (func >> bit_idx) & 1 816 clauses.append([-match_var, f3[i][p][q][r] if expected else -f3[i][p][q][r]]) 817 clauses.append(or_clause) 818 else: 819 or_clause = [] 820 for func in allowed_4input: 821 match_var = new_var() 822 or_clause.append(match_var) 823 for p in range(2): 824 for q in range(2): 825 for r in range(2): 826 for s in range(2): 827 bit_idx = p * 8 + q * 4 + r * 2 + s 828 expected = (func >> bit_idx) & 1 829 clauses.append([-match_var, f4[i][p][q][r][s] if expected else -f4[i][p][q][r][s]]) 830 clauses.append(or_clause) 831 832 # Constraint 4: Each output assigned to exactly one node 833 for h in range(n_outputs): 834 clauses.append([g[h][i] for i in range(n_nodes)]) 835 for i in range(n_nodes): 836 for j in range(i + 1, n_nodes): 837 clauses.append([-g[h][i], -g[h][j]]) 838 839 # Constraint 5: Output correctness 840 for h, segment in enumerate(SEGMENT_NAMES): 841 for t_idx, t in enumerate(truth_rows): 842 expected = 1 if t in SEGMENT_MINTERMS[segment] else 0 843 for i in range(n_nodes): 844 clauses.append([-g[h][i], x[i][t_idx] if expected else -x[i][t_idx]]) 845 846 return { 847 'clauses': clauses, 848 'n_vars': var_counter[0] - 1, 849 'gate_sizes': gate_sizes, 850 'n_inputs': n_inputs, 851 'n_nodes': n_nodes, 852 'use_complements': use_complements, 853 'x': x, 's2': s2, 's3': s3, 's4': s4, 854 'f2': f2, 'f3': f3, 'f4': f4, 'g': g, 855 } 856 857 def _decode_general_solution_from_cnf(self, model: set, cnf_data: dict) -> SynthesisResult: 858 """Decode a SAT solution using stored CNF metadata.""" 859 def is_true(var): 860 return var in model 861 862 gate_sizes = cnf_data['gate_sizes'] 863 n_inputs = cnf_data['n_inputs'] 864 n_nodes = cnf_data['n_nodes'] 865 use_complements = cnf_data['use_complements'] 866 s2, s3, s4 = cnf_data['s2'], cnf_data['s3'], cnf_data['s4'] 867 f2, f3, f4 = cnf_data['f2'], cnf_data['f3'], cnf_data['f4'] 868 g = cnf_data['g'] 869 870 n_gates = len(gate_sizes) 871 if use_complements: 872 node_names = ['A', 'B', 'C', 'D', "A'", "B'", "C'", "D'"] + [f'g{i}' for i in range(n_gates)] 873 else: 874 node_names = ['A', 'B', 'C', 'D'] + [f'g{i}' for i in range(n_gates)] 875 876 gates = [] 877 total_cost = 0 878 879 for gate_idx in range(n_gates): 880 i = n_inputs + gate_idx 881 size = gate_sizes[gate_idx] 882 total_cost += size 883 884 if size == 2: 885 for j in range(i): 886 for k in range(j + 1, i): 887 if is_true(s2[i][j][k]): 888 func = sum((1 << (p * 2 + q)) for p in range(2) for q in range(2) if is_true(f2[i][p][q])) 889 func_name = self._decode_gate_function(func) 890 gates.append(GateInfo(index=gate_idx, input1=j, input2=k, func=func, func_name=func_name)) 891 node_names[i] = f"({node_names[j]} {func_name} {node_names[k]})" 892 break 893 elif size == 3: 894 for j in range(i): 895 for k in range(j + 1, i): 896 for l in range(k + 1, i): 897 if is_true(s3[i][j][k][l]): 898 func = sum((1 << (p * 4 + q * 2 + r)) for p in range(2) for q in range(2) for r in range(2) if is_true(f3[i][p][q][r])) 899 func_name = self._decode_3input_function(func) 900 gates.append(GateInfo(index=gate_idx, input1=j, input2=(k, l), func=func, func_name=func_name)) 901 node_names[i] = f"({node_names[j]} {func_name} {node_names[k]} {node_names[l]})" 902 break 903 else: 904 for j in range(i): 905 for k in range(j + 1, i): 906 for l in range(k + 1, i): 907 for m in range(l + 1, i): 908 if is_true(s4[i][j][k][l][m]): 909 func = sum((1 << (p * 8 + q * 4 + r * 2 + s)) for p in range(2) for q in range(2) for r in range(2) for s in range(2) if is_true(f4[i][p][q][r][s])) 910 func_name = self._decode_4input_function(func) 911 gates.append(GateInfo(index=gate_idx, input1=j, input2=(k, l, m), func=func, func_name=func_name)) 912 node_names[i] = f"({node_names[j]} {func_name} {node_names[k]} {node_names[l]} {node_names[m]})" 913 break 914 915 output_map = {} 916 expressions = {} 917 for h, segment in enumerate(SEGMENT_NAMES): 918 for i in range(n_nodes): 919 if is_true(g[h][i]): 920 output_map[segment] = i 921 expressions[segment] = node_names[i] 922 break 923 924 num_2 = gate_sizes.count(2) 925 num_3 = gate_sizes.count(3) 926 num_4 = gate_sizes.count(4) 927 928 return SynthesisResult( 929 cost=total_cost, 930 implicants_by_output={}, 931 shared_implicants=[], 932 method=f"exact_general_{num_2}x2_{num_3}x3_{num_4}x4", 933 expressions=expressions, 934 cost_breakdown=CostBreakdown(and_inputs=total_cost, or_inputs=0, num_and_gates=n_gates, num_or_gates=0), 935 gates=gates, 936 output_map=output_map, 937 ) 938 939 def _try_general_synthesis(self, num_2input: int, num_3input: int, num_4input: int, 940 use_complements: bool = True, restrict_functions: bool = True) -> Optional[SynthesisResult]: 941 """Try synthesis with a mix of 2, 3, and 4-input gates.""" 942 n_primary = 4 943 n_inputs = 8 if use_complements else 4 944 n_outputs = 7 945 n_gates = num_2input + num_3input + num_4input 946 n_nodes = n_inputs + n_gates 947 948 truth_rows = list(range(10)) 949 n_rows = len(truth_rows) 950 951 cnf = CNF() 952 var_counter = [1] 953 954 def new_var(): 955 v = var_counter[0] 956 var_counter[0] += 1 957 return v 958 959 # x[i][t] = output of node i on row t 960 x = {i: {t: new_var() for t in range(n_rows)} for i in range(n_nodes)} 961 962 # Selection and function variables for each gate size 963 s2, s3, s4 = {}, {}, {} 964 f2, f3, f4 = {}, {}, {} 965 966 # Assign gate types: first num_2input are 2-input, then num_3input are 3-input, rest are 4-input 967 gate_sizes = [2] * num_2input + [3] * num_3input + [4] * num_4input 968 969 for gate_idx in range(n_gates): 970 i = n_inputs + gate_idx 971 size = gate_sizes[gate_idx] 972 973 if size == 2: 974 s2[i] = {} 975 for j in range(i): 976 s2[i][j] = {k: new_var() for k in range(j + 1, i)} 977 f2[i] = {p: {q: new_var() for q in range(2)} for p in range(2)} 978 elif size == 3: 979 s3[i] = {} 980 for j in range(i): 981 s3[i][j] = {} 982 for k in range(j + 1, i): 983 s3[i][j][k] = {l: new_var() for l in range(k + 1, i)} 984 f3[i] = {p: {q: {r: new_var() for r in range(2)} for q in range(2)} for p in range(2)} 985 else: # size == 4 986 s4[i] = {} 987 for j in range(i): 988 s4[i][j] = {} 989 for k in range(j + 1, i): 990 s4[i][j][k] = {} 991 for l in range(k + 1, i): 992 s4[i][j][k][l] = {m: new_var() for m in range(l + 1, i)} 993 f4[i] = {p: {q: {r: {s: new_var() for s in range(2)} for r in range(2)} for q in range(2)} for p in range(2)} 994 995 # g[h][i] = output h comes from node i 996 g = {h: {i: new_var() for i in range(n_nodes)} for h in range(n_outputs)} 997 998 # Constraint 1: Primary inputs fixed by truth table 999 for t_idx, t in enumerate(truth_rows): 1000 for i in range(n_primary): 1001 bit = (t >> (n_primary - 1 - i)) & 1 1002 cnf.append([x[i][t_idx] if bit else -x[i][t_idx]]) 1003 if use_complements: 1004 for i in range(n_primary): 1005 bit = (t >> (n_primary - 1 - i)) & 1 1006 cnf.append([x[n_primary + i][t_idx] if not bit else -x[n_primary + i][t_idx]]) 1007 1008 # Constraint 2: Each gate has exactly one input selection 1009 for gate_idx in range(n_gates): 1010 i = n_inputs + gate_idx 1011 size = gate_sizes[gate_idx] 1012 1013 if size == 2: 1014 all_sels = [s2[i][j][k] for j in range(i) for k in range(j + 1, i)] 1015 elif size == 3: 1016 all_sels = [s3[i][j][k][l] for j in range(i) for k in range(j + 1, i) for l in range(k + 1, i)] 1017 else: 1018 all_sels = [s4[i][j][k][l][m] for j in range(i) for k in range(j + 1, i) for l in range(k + 1, i) for m in range(l + 1, i)] 1019 1020 if not all_sels: 1021 return None # Not enough nodes for this gate size 1022 1023 cnf.append(all_sels) # At least one 1024 for idx1, sel1 in enumerate(all_sels): 1025 for sel2 in all_sels[idx1 + 1:]: 1026 cnf.append([-sel1, -sel2]) # At most one 1027 1028 # Constraint 3: Gate function consistency 1029 for gate_idx in range(n_gates): 1030 i = n_inputs + gate_idx 1031 size = gate_sizes[gate_idx] 1032 1033 if size == 2: 1034 for j in range(i): 1035 for k in range(j + 1, i): 1036 for t_idx in range(n_rows): 1037 for pv in range(2): 1038 for qv in range(2): 1039 for outv in range(2): 1040 clause = [-s2[i][j][k]] 1041 clause.append(-x[j][t_idx] if pv else x[j][t_idx]) 1042 clause.append(-x[k][t_idx] if qv else x[k][t_idx]) 1043 clause.append(-f2[i][pv][qv] if outv else f2[i][pv][qv]) 1044 clause.append(x[i][t_idx] if outv else -x[i][t_idx]) 1045 cnf.append(clause) 1046 elif size == 3: 1047 for j in range(i): 1048 for k in range(j + 1, i): 1049 for l in range(k + 1, i): 1050 for t_idx in range(n_rows): 1051 for pv in range(2): 1052 for qv in range(2): 1053 for rv in range(2): 1054 for outv in range(2): 1055 clause = [-s3[i][j][k][l]] 1056 clause.append(-x[j][t_idx] if pv else x[j][t_idx]) 1057 clause.append(-x[k][t_idx] if qv else x[k][t_idx]) 1058 clause.append(-x[l][t_idx] if rv else x[l][t_idx]) 1059 clause.append(-f3[i][pv][qv][rv] if outv else f3[i][pv][qv][rv]) 1060 clause.append(x[i][t_idx] if outv else -x[i][t_idx]) 1061 cnf.append(clause) 1062 else: # size == 4 1063 for j in range(i): 1064 for k in range(j + 1, i): 1065 for l in range(k + 1, i): 1066 for m in range(l + 1, i): 1067 for t_idx in range(n_rows): 1068 for pv in range(2): 1069 for qv in range(2): 1070 for rv in range(2): 1071 for sv in range(2): 1072 for outv in range(2): 1073 clause = [-s4[i][j][k][l][m]] 1074 clause.append(-x[j][t_idx] if pv else x[j][t_idx]) 1075 clause.append(-x[k][t_idx] if qv else x[k][t_idx]) 1076 clause.append(-x[l][t_idx] if rv else x[l][t_idx]) 1077 clause.append(-x[m][t_idx] if sv else x[m][t_idx]) 1078 clause.append(-f4[i][pv][qv][rv][sv] if outv else f4[i][pv][qv][rv][sv]) 1079 clause.append(x[i][t_idx] if outv else -x[i][t_idx]) 1080 cnf.append(clause) 1081 1082 # Constraint 3b: Restrict to standard gate functions 1083 if restrict_functions: 1084 # 2-input: AND, OR, XOR, XNOR, NAND, NOR 1085 allowed_2input = [0b1000, 0b1110, 0b0110, 0b1001, 0b0111, 0b0001] 1086 1087 # 3-input: AND3, OR3, XOR3, XNOR3, NAND3, NOR3 1088 allowed_3input = [0b10000000, 0b11111110, 0b01111111, 0b00000001, 0b10010110, 0b01101001] 1089 1090 # 4-input: AND4, OR4, XOR4, XNOR4, NAND4, NOR4 1091 allowed_4input = [0x8000, 0xFFFE, 0x7FFF, 0x0001, 0x6996, 0x9669] 1092 1093 for gate_idx in range(n_gates): 1094 i = n_inputs + gate_idx 1095 size = gate_sizes[gate_idx] 1096 1097 if size == 2: 1098 or_clause = [] 1099 for func in allowed_2input: 1100 match_var = new_var() 1101 or_clause.append(match_var) 1102 for p in range(2): 1103 for q in range(2): 1104 bit_idx = p * 2 + q 1105 expected = (func >> bit_idx) & 1 1106 if expected: 1107 cnf.append([-match_var, f2[i][p][q]]) 1108 else: 1109 cnf.append([-match_var, -f2[i][p][q]]) 1110 cnf.append(or_clause) 1111 elif size == 3: 1112 or_clause = [] 1113 for func in allowed_3input: 1114 match_var = new_var() 1115 or_clause.append(match_var) 1116 for p in range(2): 1117 for q in range(2): 1118 for r in range(2): 1119 bit_idx = p * 4 + q * 2 + r 1120 expected = (func >> bit_idx) & 1 1121 if expected: 1122 cnf.append([-match_var, f3[i][p][q][r]]) 1123 else: 1124 cnf.append([-match_var, -f3[i][p][q][r]]) 1125 cnf.append(or_clause) 1126 else: # size == 4 1127 or_clause = [] 1128 for func in allowed_4input: 1129 match_var = new_var() 1130 or_clause.append(match_var) 1131 for p in range(2): 1132 for q in range(2): 1133 for r in range(2): 1134 for s in range(2): 1135 bit_idx = p * 8 + q * 4 + r * 2 + s 1136 expected = (func >> bit_idx) & 1 1137 if expected: 1138 cnf.append([-match_var, f4[i][p][q][r][s]]) 1139 else: 1140 cnf.append([-match_var, -f4[i][p][q][r][s]]) 1141 cnf.append(or_clause) 1142 1143 # Constraint 4: Each output assigned to exactly one node 1144 for h in range(n_outputs): 1145 cnf.append([g[h][i] for i in range(n_nodes)]) 1146 for i in range(n_nodes): 1147 for j in range(i + 1, n_nodes): 1148 cnf.append([-g[h][i], -g[h][j]]) 1149 1150 # Constraint 5: Output correctness 1151 for h, segment in enumerate(SEGMENT_NAMES): 1152 for t_idx, t in enumerate(truth_rows): 1153 expected = 1 if t in SEGMENT_MINTERMS[segment] else 0 1154 for i in range(n_nodes): 1155 if expected: 1156 cnf.append([-g[h][i], x[i][t_idx]]) 1157 else: 1158 cnf.append([-g[h][i], -x[i][t_idx]]) 1159 1160 # Solve 1161 with Solver(bootstrap_with=cnf) as solver: 1162 if solver.solve(): 1163 model = set(solver.get_model()) 1164 return self._decode_general_solution( 1165 model, gate_sizes, n_inputs, n_nodes, 1166 x, s2, s3, s4, f2, f3, f4, g, use_complements 1167 ) 1168 return None 1169 1170 def _decode_general_solution(self, model, gate_sizes, n_inputs, n_nodes, 1171 x, s2, s3, s4, f2, f3, f4, g, use_complements) -> SynthesisResult: 1172 """Decode SAT solution for general mixed gate sizes.""" 1173 def is_true(var): 1174 return var in model 1175 1176 n_gates = len(gate_sizes) 1177 if use_complements: 1178 node_names = ['A', 'B', 'C', 'D', "A'", "B'", "C'", "D'"] + [f'g{i}' for i in range(n_gates)] 1179 else: 1180 node_names = ['A', 'B', 'C', 'D'] + [f'g{i}' for i in range(n_gates)] 1181 1182 gates = [] 1183 total_cost = 0 1184 1185 for gate_idx in range(n_gates): 1186 i = n_inputs + gate_idx 1187 size = gate_sizes[gate_idx] 1188 total_cost += size 1189 1190 if size == 2: 1191 for j in range(i): 1192 for k in range(j + 1, i): 1193 if is_true(s2[i][j][k]): 1194 func = 0 1195 for p in range(2): 1196 for q in range(2): 1197 if is_true(f2[i][p][q]): 1198 func |= (1 << (p * 2 + q)) 1199 func_name = self._decode_gate_function(func) 1200 gates.append(GateInfo(index=gate_idx, input1=j, input2=k, func=func, func_name=func_name)) 1201 node_names[i] = f"({node_names[j]} {func_name} {node_names[k]})" 1202 break 1203 elif size == 3: 1204 for j in range(i): 1205 for k in range(j + 1, i): 1206 for l in range(k + 1, i): 1207 if is_true(s3[i][j][k][l]): 1208 func = 0 1209 for p in range(2): 1210 for q in range(2): 1211 for r in range(2): 1212 if is_true(f3[i][p][q][r]): 1213 func |= (1 << (p * 4 + q * 2 + r)) 1214 func_name = self._decode_3input_function(func) 1215 gates.append(GateInfo(index=gate_idx, input1=j, input2=(k, l), func=func, func_name=func_name)) 1216 node_names[i] = f"({node_names[j]} {func_name} {node_names[k]} {node_names[l]})" 1217 break 1218 else: # size == 4 1219 for j in range(i): 1220 for k in range(j + 1, i): 1221 for l in range(k + 1, i): 1222 for m in range(l + 1, i): 1223 if is_true(s4[i][j][k][l][m]): 1224 func = 0 1225 for p in range(2): 1226 for q in range(2): 1227 for r in range(2): 1228 for s in range(2): 1229 if is_true(f4[i][p][q][r][s]): 1230 func |= (1 << (p * 8 + q * 4 + r * 2 + s)) 1231 func_name = self._decode_4input_function(func) 1232 gates.append(GateInfo(index=gate_idx, input1=j, input2=(k, l, m), func=func, func_name=func_name)) 1233 node_names[i] = f"({node_names[j]} {func_name} {node_names[k]} {node_names[l]} {node_names[m]})" 1234 break 1235 1236 # Map outputs 1237 output_map = {} 1238 expressions = {} 1239 for h, segment in enumerate(SEGMENT_NAMES): 1240 for i in range(n_nodes): 1241 if is_true(g[h][i]): 1242 output_map[segment] = i 1243 expressions[segment] = node_names[i] 1244 break 1245 1246 num_2 = gate_sizes.count(2) 1247 num_3 = gate_sizes.count(3) 1248 num_4 = gate_sizes.count(4) 1249 cost_breakdown = CostBreakdown( 1250 and_inputs=total_cost, 1251 or_inputs=0, 1252 num_and_gates=n_gates, 1253 num_or_gates=0, 1254 ) 1255 1256 return SynthesisResult( 1257 cost=total_cost, 1258 implicants_by_output={}, 1259 shared_implicants=[], 1260 method=f"exact_general_{num_2}x2_{num_3}x3_{num_4}x4", 1261 expressions=expressions, 1262 cost_breakdown=cost_breakdown, 1263 gates=gates, 1264 output_map=output_map, 1265 ) 1266 1267 def _try_exact_synthesis(self, num_gates: int, use_complements: bool = False, restrict_functions: bool = False) -> Optional[SynthesisResult]: 1268 """ 1269 Try to find a circuit with exactly num_gates gates. 1270 1271 Uses a SAT encoding where: 1272 - Variables encode gate structure (which inputs each gate uses) 1273 - Variables encode gate function (AND, OR, NAND, NOR, etc.) 1274 - Constraints ensure functional correctness on all valid inputs 1275 1276 Args: 1277 num_gates: Number of 2-input gates to use 1278 use_complements: If True, include A',B',C',D' as free inputs (8 total) 1279 restrict_functions: If True, only allow AND, OR, XOR, NAND, NOR, XNOR 1280 """ 1281 n_primary = 4 # A, B, C, D 1282 n_inputs = 8 if use_complements else 4 # Include complements if requested 1283 n_outputs = 7 # a, b, c, d, e, f, g 1284 n_nodes = n_inputs + num_gates 1285 1286 # Only verify on valid BCD inputs (0-9) 1287 truth_rows = list(range(10)) 1288 n_rows = len(truth_rows) 1289 1290 cnf = CNF() 1291 var_counter = [1] 1292 1293 def new_var(): 1294 v = var_counter[0] 1295 var_counter[0] += 1 1296 return v 1297 1298 # Variables: 1299 # x[i][t] = output of node i on row t 1300 # s[i][j][k] = gate i uses inputs j and k 1301 # f[i][p][q] = gate i output when inputs are (p, q) 1302 # g[h][i] = output h comes from node i 1303 1304 x = {} 1305 s = {} 1306 f = {} 1307 g = {} 1308 1309 for i in range(n_nodes): 1310 x[i] = {t: new_var() for t in range(n_rows)} 1311 1312 for i in range(n_inputs, n_nodes): 1313 s[i] = {} 1314 for j in range(i): 1315 s[i][j] = {k: new_var() for k in range(j + 1, i)} 1316 f[i] = {p: {q: new_var() for q in range(2)} for p in range(2)} 1317 1318 for h in range(n_outputs): 1319 g[h] = {i: new_var() for i in range(n_nodes)} 1320 1321 # Constraint 1: Primary inputs are fixed by truth table 1322 for t_idx, t in enumerate(truth_rows): 1323 # First 4 inputs: A, B, C, D 1324 for i in range(n_primary): 1325 bit = (t >> (n_primary - 1 - i)) & 1 1326 cnf.append([x[i][t_idx] if bit else -x[i][t_idx]]) 1327 # Next 4 inputs (if using complements): A', B', C', D' 1328 if use_complements: 1329 for i in range(n_primary): 1330 bit = (t >> (n_primary - 1 - i)) & 1 1331 # Complement is the inverse 1332 cnf.append([x[n_primary + i][t_idx] if not bit else -x[n_primary + i][t_idx]]) 1333 1334 # Constraint 2: Each gate has exactly one input pair 1335 for i in range(n_inputs, n_nodes): 1336 all_sels = [s[i][j][k] for j in range(i) for k in range(j + 1, i)] 1337 # At least one 1338 cnf.append(all_sels) 1339 # At most one 1340 for idx1, sel1 in enumerate(all_sels): 1341 for sel2 in all_sels[idx1 + 1:]: 1342 cnf.append([-sel1, -sel2]) 1343 1344 # Constraint 3: Gate function consistency 1345 for i in range(n_inputs, n_nodes): 1346 for j in range(i): 1347 for k in range(j + 1, i): 1348 for t_idx in range(n_rows): 1349 for pv in range(2): 1350 for qv in range(2): 1351 for outv in range(2): 1352 # If s[i][j][k] ∧ x[j][t]=pv ∧ x[k][t]=qv ∧ f[i][pv][qv]=outv 1353 # then x[i][t]=outv 1354 clause = [-s[i][j][k]] 1355 clause.append(-x[j][t_idx] if pv else x[j][t_idx]) 1356 clause.append(-x[k][t_idx] if qv else x[k][t_idx]) 1357 clause.append(-f[i][pv][qv] if outv else f[i][pv][qv]) 1358 clause.append(x[i][t_idx] if outv else -x[i][t_idx]) 1359 cnf.append(clause) 1360 1361 # Constraint 3b: Restrict to standard gate functions (if requested) 1362 # With complements available, we only need symmetric functions 1363 if restrict_functions: 1364 # Allowed: AND(1000), OR(1110), XOR(0110), NAND(0111), NOR(0001), XNOR(1001) 1365 allowed_funcs = [0b1000, 0b1110, 0b0110, 0b0111, 0b0001, 0b1001] 1366 for i in range(n_inputs, n_nodes): 1367 # For each gate, the function must be one of the allowed ones 1368 # Encode as: (func == AND) OR (func == OR) OR ... 1369 or_clause = [] 1370 for func in allowed_funcs: 1371 # Create aux var for "this gate has this function" 1372 match_var = new_var() 1373 or_clause.append(match_var) 1374 # match_var -> all f bits match the function 1375 for p in range(2): 1376 for q in range(2): 1377 bit_idx = p * 2 + q 1378 expected = (func >> bit_idx) & 1 1379 if expected: 1380 cnf.append([-match_var, f[i][p][q]]) 1381 else: 1382 cnf.append([-match_var, -f[i][p][q]]) 1383 # At least one match_var must be true 1384 cnf.append(or_clause) 1385 1386 # Constraint 4: Each output assigned to exactly one node 1387 for h in range(n_outputs): 1388 cnf.append([g[h][i] for i in range(n_nodes)]) 1389 for i in range(n_nodes): 1390 for j in range(i + 1, n_nodes): 1391 cnf.append([-g[h][i], -g[h][j]]) 1392 1393 # Constraint 5: Output correctness 1394 for h, segment in enumerate(SEGMENT_NAMES): 1395 for t_idx, t in enumerate(truth_rows): 1396 expected = 1 if t in SEGMENT_MINTERMS[segment] else 0 1397 for i in range(n_nodes): 1398 if expected: 1399 cnf.append([-g[h][i], x[i][t_idx]]) 1400 else: 1401 cnf.append([-g[h][i], -x[i][t_idx]]) 1402 1403 # Solve 1404 with Solver(bootstrap_with=cnf) as solver: 1405 if solver.solve(): 1406 model = set(solver.get_model()) 1407 return self._decode_exact_solution( 1408 model, num_gates, n_inputs, n_nodes, x, s, f, g, use_complements 1409 ) 1410 return None 1411 1412 def _decode_exact_solution( 1413 self, model, num_gates, n_inputs, n_nodes, x, s, f, g, use_complements: bool = False 1414 ) -> SynthesisResult: 1415 """Decode SAT solution into readable circuit description.""" 1416 1417 def is_true(var): 1418 return var in model 1419 1420 if use_complements: 1421 node_names = ['A', 'B', 'C', 'D', "A'", "B'", "C'", "D'"] + [f'g{i}' for i in range(num_gates)] 1422 else: 1423 node_names = ['A', 'B', 'C', 'D'] + [f'g{i}' for i in range(num_gates)] 1424 gates = [] 1425 1426 for i in range(n_inputs, n_nodes): 1427 for j in range(i): 1428 for k in range(j + 1, i): 1429 if is_true(s[i][j][k]): 1430 # Decode gate function 1431 func = 0 1432 for p in range(2): 1433 for q in range(2): 1434 if is_true(f[i][p][q]): 1435 func |= (1 << (p * 2 + q)) 1436 1437 func_name = self._decode_gate_function(func) 1438 gates.append(GateInfo( 1439 index=i - n_inputs, 1440 input1=j, 1441 input2=k, 1442 func=func, 1443 func_name=func_name, 1444 )) 1445 1446 # Build expression string 1447 expr = f"({node_names[j]} {func_name} {node_names[k]})" 1448 node_names[i] = expr 1449 break 1450 1451 # Map outputs to nodes 1452 output_map = {} 1453 expressions = {} 1454 for h, segment in enumerate(SEGMENT_NAMES): 1455 for i in range(n_nodes): 1456 if is_true(g[h][i]): 1457 output_map[segment] = i 1458 expressions[segment] = node_names[i] 1459 break 1460 1461 # For exact synthesis, all gates are 2-input gates 1462 cost_breakdown = CostBreakdown( 1463 and_inputs=num_gates * 2, 1464 or_inputs=0, 1465 num_and_gates=num_gates, 1466 num_or_gates=0, 1467 ) 1468 1469 return SynthesisResult( 1470 cost=num_gates * 2, 1471 implicants_by_output={}, 1472 shared_implicants=[], 1473 method=f"exact_{num_gates}gates", 1474 expressions=expressions, 1475 cost_breakdown=cost_breakdown, 1476 gates=gates, 1477 output_map=output_map, 1478 ) 1479 1480 def _decode_gate_function(self, func: int) -> str: 1481 """Decode 4-bit function to gate type name.""" 1482 # func encodes 2-input truth table: bit i = f(p,q) where i = p*2 + q 1483 # Bit 0: f(0,0), Bit 1: f(0,1), Bit 2: f(1,0), Bit 3: f(1,1) 1484 names = { 1485 0b0000: "0", # constant 0 1486 0b0001: "NOR", # 1 only when both inputs 0 1487 0b0010: "B>A", # B AND NOT A (inhibit) 1488 0b0011: "!A", # NOT first input 1489 0b0100: "A>B", # A AND NOT B (inhibit) 1490 0b0101: "!B", # NOT second input 1491 0b0110: "XOR", # exclusive or 1492 0b0111: "NAND", # NOT (A AND B) 1493 0b1000: "AND", # A AND B 1494 0b1001: "XNOR", # NOT (A XOR B) 1495 0b1010: "B", # pass through second input 1496 0b1011: "!A+B", # NOT A OR B (implication) 1497 0b1100: "A", # pass through first input 1498 0b1101: "A+!B", # A OR NOT B (implication) 1499 0b1110: "OR", # A OR B 1500 0b1111: "1", # constant 1 1501 } 1502 return names.get(func, f"F{func:04b}") 1503 1504 def solve(self, target_cost: int = 22, use_exact: bool = False) -> SynthesisResult: 1505 """ 1506 Run the complete optimization pipeline. 1507 1508 Args: 1509 target_cost: Target gate input count to beat 1510 use_exact: If True, use SAT-based exact synthesis (slower) 1511 1512 Returns: 1513 Best synthesis result found 1514 """ 1515 results = [] 1516 1517 # Phase 1: Generate primes and greedy baseline 1518 print("Phase 1: Generating prime implicants...") 1519 self.generate_prime_implicants() 1520 print(f" Found {len(self.prime_implicants)} prime implicants") 1521 1522 print("\nPhase 1b: Greedy set cover baseline...") 1523 greedy_result = self.greedy_baseline() 1524 results.append(greedy_result) 1525 print(f" Greedy cost: {greedy_result.cost} gate inputs") 1526 1527 # Phase 2: MaxSAT optimization 1528 print("\nPhase 2: MaxSAT optimization with sharing...") 1529 maxsat_result = self.maxsat_optimize(target_cost) 1530 results.append(maxsat_result) 1531 print(f" MaxSAT cost: {maxsat_result.cost} gate inputs") 1532 print(f" Shared terms: {len(maxsat_result.shared_implicants)}") 1533 1534 # Phase 3: Exact synthesis (optional) 1535 if use_exact: 1536 print("\nPhase 3: SAT-based exact synthesis...") 1537 try: 1538 exact_result = self.exact_synthesis(max_gates=12) 1539 results.append(exact_result) 1540 print(f" Exact cost: {exact_result.cost} gate inputs") 1541 except RuntimeError as e: 1542 print(f" Exact synthesis failed: {e}") 1543 1544 # Return best result 1545 best = min(results, key=lambda r: r.cost) 1546 print(f"\nBest result: {best.cost} gate inputs ({best.method})") 1547 1548 return best 1549 1550 def print_result(self, result: SynthesisResult): 1551 """Pretty-print a synthesis result.""" 1552 print(f"\n{'=' * 60}") 1553 print(f"Synthesis Result: {result.method}") 1554 print(f"{'=' * 60}") 1555 1556 if result.cost_breakdown: 1557 cb = result.cost_breakdown 1558 print(f"Cost breakdown:") 1559 print(f" AND gate inputs: {cb.and_inputs} ({cb.num_and_gates} gates)") 1560 print(f" OR gate inputs: {cb.or_inputs} (7 gates)") 1561 print(f" Total: {cb.total} gate inputs") 1562 else: 1563 print(f"Total gate inputs: {result.cost}") 1564 1565 if result.shared_implicants: 1566 print(f"\nShared terms ({len(result.shared_implicants)}):") 1567 for impl, outputs in result.shared_implicants: 1568 lit_info = f"({impl.num_literals} lit)" if impl.num_literals >= 2 else "(wire)" 1569 print(f" {impl.to_expr_str():12} {lit_info:8} -> {', '.join(outputs)}") 1570 1571 print("\nExpressions:") 1572 for segment in SEGMENT_NAMES: 1573 if segment in result.expressions: 1574 print(f" {segment} = {result.expressions[segment]}")